diff --git a/CHANGELOG.md b/CHANGELOG.md index 14de0a7..5a3c927 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ # Change Log All notable changes to this project will be documented in this file. + +## 2.2.0 - 2026-03 + ### Runner + - allowing for custom table name (has priority before class name) + - added options to add filter on dependencies and target table based on column-value pairs + - target table can now selectively write based on secondary virtual partitions + ### Common + - table reader can also filter based on given column-values pairs + ## 2.1.4 - 2026-02 ### Common - table reader optimization diff --git a/README.md b/README.md index 294522f..d78f887 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,44 @@ This library currently contains: Runner is the orchestrator and scheduler of Rialto. It can be used to execute any [job](#jobs) but is primarily designed to execute feature [maker](#maker) jobs. The core of runner is execution of a Transformation class that can be extended for any purpose and the execution configuration that defines the handling of i/o, scheduling, dependencies and reporting. +Runner operates on the assumption that your Databricks tables contain a date column (partition column) that indicates the date of data arrival. This enables: + +1. **Time-aware computation**: Run operations on dated tables while managing time-related dependencies automatically +2. **Retrospective simulation**: Run computations as they would have occurred on a specific date by setting `run_date` - ensuring data newer than that date is never used +3. **Dependency tracking**: Automatically verify that input data meets required freshness constraints relative to the run date +4. **Automatic completion detection**: Skip computations when output data already exists (configurable with `rerun` parameter) + +#### Scheduling and Execution + +Runner uses a schedule-based approach: +* Define a schedule (e.g., weekly on day 2, monthly on day 6) in the configuration +* Specify a watch period (how far back to look for missing runs) +* Runner finds all scheduled run dates within the watch period and executes missing ones +* For each date, dependencies are checked and target existence is verified before execution + +#### Data Flow + +For each pipeline execution: +1. **Dependency verification**: Check that all required input tables have data within specified time intervals +2. **Transformation execution**: Run your transformation to produce a Spark DataFrame +3. **Automatic enrichment**: Runner adds `INFORMATION_DATE` (the run date) and `VERSION` (package version) columns +4. **Partitioned write**: Data is written to Databricks with partitioning configuration +5. **Reporting**: Optional email notifications on failure and run information stored to tracking table + +#### Dependency Tracking + +Runner's dependency tracking ensures that all required input data is available before executing a pipeline. For each dependency: + +* **Date-based checking**: Runner looks for data in the dependency table's date column +* **Interval calculation**: The required date is calculated by subtracting the dependency's interval from the run date +* **Existence verification**: Runner checks if data exists for the calculated date (and within any specified filters) +* **Missing data handling**: If required data is missing, Runner raises an error for that specific pipeline/date but continues executing other pipelines and dates in the queue + +**Example:** If you're running a pipeline on 2024-01-15 with a dependency that has a 7-day interval: +* Runner checks if the dependency table has data for 2024-01-08 (15 days - 7 days) +* If the dependency has `filters: {VERSION: "v2"}`, it specifically checks for data where VERSION='v2' +* If data exists, the pipeline proceeds; otherwise, an error is raised for this specific execution, but other scheduled runs continue + ### Transformation For the details on the interface see the [implementation](rialto/runner/transformation.py) Inside the transformation you have access to a [TableReader](#common), date of running, and if provided to Runner, a live spark session and [metadata manager](#metadata). @@ -64,6 +102,13 @@ Transformations are not included in the runner itself, it imports them dynamical ### Configuration +Runner is supplied with a run configuration that defines the computations it will execute. In each pipeline configuration you define: +- **Module**: Python transformation class to execute +- **Schedule**: When to run (daily/weekly/monthly) and on which day +- **Dependencies**: Input tables to check, with required freshness intervals and optional filters +- **Target**: Output location, partitioning strategy, and optional completion filters + + ```yaml runner: watched_period_units: "months" # unit of default run period @@ -129,9 +174,19 @@ pipelines: # a list of pipelines to run interval: units: "days" value: 6 + filters: + dep_column_name1: "value1" + dep_column_name2: "value2" target: target_schema: catalog.schema # schema where tables will be created, must exist target_partition_column: INFORMATION_DATE # date to partition new tables on + secondary_partition_columns: # optional list of secondary partitions to ensure partial-overwrite of the target table based on generated data for these partitions + - column_name1 + - column_name2 + rerun_filters: # optional filters to avoid reruning already generated data for secondary partitionons, if secondary partition values are dynamically generated at runtime, leave this empty but the job will always rerun + column_name1: 42 + column_name2: "some_value" + custom_name: "custom_table_name" # optional custom table name, if not provided, the table name will be the same as pipeline name ``` The configuration can be dynamically overridden by providing a dictionary of overrides to the runner. All overrides must adhere to configurations schema, with pipeline.extras section available for custom schema. @@ -189,6 +244,13 @@ overrides={"runner.watched_period_value": 4, } ``` +### Multiple partitions +Rialto runner and TableReader can handle multiple "partitions", however we only use one primary physical partitions and treat selected columns as other partitions. +When wanting to write to a selected secondary partition/s, you can specify them in the configuration file as **secondary_partition_columns** and provide values for these columns in **rerun_filters**. This way, the runner will only rerun for the data that matches the filters, and leave the rest of the data intact. +You can use env variables to set these filters if the values are available before the job run. If not, the job can be setup without these filters, however by defining secondary target partitions, the job will always rerun because it can't determine whether its supposed to run. + +You can also take advantage of these **filters** options in dependency configuration to ensure the right data is available. + ## 2.2 - maker The purpose of (feature) maker is to simplify feature creation, allow for consistent feature implementation that is standardized and easy to test. diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 16a6b38..228d59b 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -16,7 +16,7 @@ import abc import datetime -from typing import Optional +from typing import Dict, Optional import pyspark.sql.functions as F from pyspark.sql import DataFrame, SparkSession @@ -37,6 +37,7 @@ def get_latest( date_column: str, date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, + filters: Optional[Dict[str, str]] = None, ) -> DataFrame: """ Get latest available date partition of the table until specified date @@ -44,6 +45,7 @@ def get_latest( :param table: input table path :param date_until: Optional until date (inclusive) :param uppercase_columns: Option to refactor all column names to uppercase + :param filters: Optional dict of column filters to apply before finding latest date :return: Dataframe """ raise NotImplementedError @@ -56,6 +58,7 @@ def get_table( date_from: Optional[datetime.date] = None, date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, + filters: Optional[Dict[str, str]] = None, ) -> DataFrame: """ Get a whole table or a slice by selected dates @@ -96,12 +99,16 @@ def _get_latest_available_date(self, df: DataFrame, date_col: str, until: Option df = df.select(F.max(date_col)).alias("latest") return df.head()[0] + def _get_raw_data(self, table: str) -> DataFrame: + return self.spark.read.table(table) + def get_latest( self, table: str, date_column: str, date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, + filters: Optional[Dict[str, str]] = None, ) -> DataFrame: """ Get latest available date partition of the table until specified date @@ -110,9 +117,14 @@ def get_latest( :param date_until: Optional until date (inclusive) :param date_column: column to filter dates on, takes highest priority :param uppercase_columns: Option to refactor all column names to uppercase + :param filters: Optional dict of column filters to apply before finding latest date :return: Dataframe """ - df = self.spark.read.table(table) + df = self._get_raw_data(table) + + if filters: + for col, val in filters.items(): + df = df.filter(df[col] == val) selected_date = self._get_latest_available_date(df, date_column, date_until) df = df.filter(F.col(date_column) == selected_date) @@ -128,6 +140,7 @@ def get_table( date_from: Optional[datetime.date] = None, date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, + filters: Optional[Dict] = None, ) -> DataFrame: """ Get a whole table or a slice by selected dates @@ -139,7 +152,11 @@ def get_table( :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ - df = self.spark.read.table(table) + df = self._get_raw_data(table) + + if filters: + for col, val in filters.items(): + df = df.filter(df[col] == val) if date_from: df = df.filter(F.col(date_column) >= date_from) diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index cc7c60b..7978ac5 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -44,6 +44,7 @@ class DependencyConfig(BaseConfig): name: Optional[str] = None date_col: str interval: IntervalConfig + filters: Optional[Dict] = None class ModuleConfig(BaseConfig): @@ -69,6 +70,9 @@ class RunnerConfig(BaseConfig): class TargetConfig(BaseConfig): target_schema: str target_partition_column: str + secondary_partition_columns: Optional[List[str]] = None + rerun_filters: Optional[Dict] = None + custom_name: Optional[str] = None class MetadataManagerConfig(BaseConfig): diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 2807a77..384998f 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -18,7 +18,6 @@ from datetime import date from typing import Dict, List, Tuple -import pyspark.sql.functions as F from loguru import logger from pyspark.sql import DataFrame, SparkSession @@ -97,7 +96,7 @@ def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineC return df - def _check_written(self, info_date: date, table: Table) -> int: + def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: """ Check if there are records written for given date @@ -105,11 +104,23 @@ def _check_written(self, info_date: date, table: Table) -> int: :param table: target table object :return: number of records """ - df = self.spark.read.table(table.get_table_path()) - df = df.filter(F.col(table.partition) == info_date) + filters = {} + if pipeline.target.rerun_filters is not None: + filters = pipeline.target.rerun_filters + else: + if table.secondary_partitions: + row = df.select(*table.secondary_partitions).distinct().collect()[0] + for c in table.secondary_partitions: + val = row[0][c] + filters[c] = val + + df = self.reader.get_table( + table.get_table_path(), date_column=table.partition, date_from=info_date, date_to=info_date, filters=filters + ) + return df.count() - def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bool]: + def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: """ For given list of dates, check if there is a matching partition for each @@ -118,8 +129,21 @@ def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bo :return: list of bool """ if utils.table_exists(self.spark, table.get_table_path()): - partitions = utils.get_partitions(self.reader, table) - return [(date in partitions) for date in dates] + checks = [] + for check_date in dates: + df = self.reader.get_table( + table.get_table_path(), + date_column=table.partition, + date_from=check_date, + date_to=check_date, + filters=target_filters, + ) + data_exists = df.count() > 0 + if data_exists and target_filters is None and table.secondary_partitions is not None: + # ensure rerun if the write consideres secondary partitions but the filter doesn't + data_exists = False + checks.append(data_exists) + return checks else: logger.info(f"Table {table.get_table_path()} doesn't exist!") return [False for _ in dates] @@ -145,7 +169,7 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") source = Table(table_path=dependency.table, partition=dependency.date_col) - if True in self.check_dates_have_partition(source, possible_dep_dates): + if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") else: msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" @@ -158,7 +182,7 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: return True - def _get_completion(self, target: Table, info_dates: List[date]) -> List[bool]: + def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: """ Check if model has run for given dates @@ -169,9 +193,9 @@ def _get_completion(self, target: Table, info_dates: List[date]) -> List[bool]: if self.rerun: return [False for _ in info_dates] else: - return self.check_dates_have_partition(target, info_dates) + return self.check_dates_have_data(target, info_dates, filters) - def _select_run_dates(self, pipeline: PipelineConfig, table: Table) -> Tuple[List, List]: + def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]: """ Select run dates and info dates based on completion @@ -181,7 +205,7 @@ def _select_run_dates(self, pipeline: PipelineConfig, table: Table) -> Tuple[Lis """ possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] - current_state = self._get_completion(table, possible_info_dates) + current_state = self._get_completion(table, possible_info_dates, filters) selection = [ (run, info) for run, info, state in zip(possible_run_dates, possible_info_dates, current_state) if not state @@ -212,7 +236,7 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat feature_group = utils.load_module(pipeline.module) df = self._execute(feature_group, run_date, pipeline) self.writer.write(df, info_date, target) - records = self._check_written(info_date, target) + records = self._check_written(info_date, target, df, pipeline) logger.info(f"Generated {records} records") if records == 0: raise RuntimeError("No records generated") @@ -231,10 +255,14 @@ def _run_pipeline(self, pipeline: PipelineConfig): schema_path=pipeline.target.target_schema, class_name=pipeline.module.python_class, partition=pipeline.target.target_partition_column, + secondary_partitions=pipeline.target.secondary_partition_columns, + table=pipeline.target.custom_name, ) logger.info(f"Loaded pipeline {pipeline.name}") - selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) + selected_run_dates, selected_info_dates = self._select_run_dates( + pipeline, target, pipeline.target.rerun_filters + ) # ----------- Checking dependencies available ---------- for run_date, info_date in zip(selected_run_dates, selected_info_dates): @@ -317,6 +345,8 @@ def debug(self) -> DataFrame: schema_path=pipeline.target.target_schema, class_name=pipeline.module.python_class, partition=pipeline.target.target_partition_column, + secondary_partitions=pipeline.target.secondary_partition_columns, + table=pipeline.target.custom_name, ) selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) if len(selected_run_dates) > 0: diff --git a/rialto/runner/table.py b/rialto/runner/table.py index 416bedb..2d44498 100644 --- a/rialto/runner/table.py +++ b/rialto/runner/table.py @@ -14,6 +14,8 @@ __all__ = ["Table"] +from typing import List + from rialto.metadata import class_to_catalog_name @@ -29,11 +31,13 @@ def __init__( table_path: str = None, class_name: str = None, partition: str = None, + secondary_partitions: List[str] = None, ): self.catalog = catalog self.schema = schema self.table = table self.partition = partition + self.secondary_partitions = secondary_partitions if schema_path: schema_path = schema_path.split(".") self.catalog = schema_path[0] @@ -43,13 +47,20 @@ def __init__( self.catalog = table_path[0] self.schema = table_path[1] self.table = table_path[2] - if class_name: + if class_name and not table: self.table = class_to_catalog_name(class_name) - def get_schema_path(self): + def get_schema_path(self) -> str: """Get path of table's schema""" return f"{self.catalog}.{self.schema}" - def get_table_path(self): + def get_table_path(self) -> str: """Get full table path""" return f"{self.catalog}.{self.schema}.{self.table}" + + def get_all_partitions(self) -> List[str]: + """Get list of all partitions""" + if self.secondary_partitions: + return [self.partition] + self.secondary_partitions + else: + return [self.partition] diff --git a/rialto/runner/writer.py b/rialto/runner/writer.py index 3e31dac..bc147fd 100644 --- a/rialto/runner/writer.py +++ b/rialto/runner/writer.py @@ -74,6 +74,23 @@ def _process(self, df: DataFrame, info_date: date, table: Table) -> DataFrame: return df + def _get_replace_condition(self, df: DataFrame, partition_cols: List[str]) -> str: + row = df.select(*partition_cols).distinct().collect() + if len(row) > 1: + raise ValueError(f"Some of the partitions to write have more than 1 distinct value \n {row}") + + parts = [] + for c in partition_cols: + val = row[0][c] + if val is None: + parts.append(f"{c} IS NULL") + elif isinstance(val, (int, float)): + parts.append(f"{c} = {val}") + else: + parts.append(f"{c} = '{val}'") + condition = " AND ".join(parts) + return condition + def write(self, df: DataFrame, info_date: date, table: Table) -> None: """ Write dataframe to storage @@ -87,10 +104,10 @@ def write(self, df: DataFrame, info_date: date, table: Table) -> None: df = self._process(df, info_date, table) - if self.merge_schema is True: - df.write.partitionBy(table.partition).mode("overwrite").option("mergeSchema", "true").saveAsTable( - table.get_table_path() - ) - else: - df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) + replace_where = self._get_replace_condition(df, table.get_all_partitions()) + + df.write.format("delta").partitionBy(table.partition).mode("overwrite").option( + "mergeSchema", "true" if self.merge_schema else "false" + ).option("replaceWhere", replace_where).saveAsTable(table.get_table_path()) + logger.info(f"Results writen to {table.get_table_path()}") diff --git a/tests/common/test_reader.py b/tests/common/test_reader.py index c42b20b..452ad13 100644 --- a/tests/common/test_reader.py +++ b/tests/common/test_reader.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import date + import pytest from rialto.common.table_reader import TableReader @@ -27,7 +29,85 @@ def sample_df(spark): return df +@pytest.fixture +def multi_partition_df(spark): + df = spark.createDataFrame( + [ + ("REGION_A", "TYPE_X", date(2023, 1, 1), 100), + ("REGION_A", "TYPE_X", date(2023, 1, 2), 200), + ("REGION_A", "TYPE_Y", date(2023, 1, 1), 150), + ("REGION_A", "TYPE_Y", date(2023, 1, 2), 250), + ("REGION_B", "TYPE_X", date(2023, 1, 1), 300), + ("REGION_B", "TYPE_X", date(2023, 1, 2), 400), + ], + schema="region string, type string, info_date date, value int", + ) + return df + + def test_uppercase_columns(spark, sample_df): tr = TableReader(spark) df = tr._uppercase_column_names(sample_df) assert df.columns == ["A", "B", "C", "D", "E"] + + +def test_get_latest_with_single_filter(multi_partition_df, mocker): + mock_spark = mocker.MagicMock() + mock_spark.read.table.return_value = multi_partition_df + + tr = TableReader(mock_spark) + result = tr.get_latest("test_table", date_column="info_date", filters={"region": "REGION_A"}) + + assert result.count() == 2 + assert all(row.info_date == date(2023, 1, 2) for row in result.collect()) + assert all(row.region == "REGION_A" for row in result.collect()) + + +def test_get_latest_with_multiple_filters(multi_partition_df, mocker): + mock_spark = mocker.MagicMock() + mock_spark.read.table.return_value = multi_partition_df + + tr = TableReader(mock_spark) + result = tr.get_latest("test_table", date_column="info_date", filters={"region": "REGION_A", "type": "TYPE_X"}) + + assert result.count() == 1 + assert result.first().info_date == date(2023, 1, 2) + assert result.first().region == "REGION_A" + assert result.first().type == "TYPE_X" + assert result.first().value == 200 + + +def test_get_latest_without_filters(multi_partition_df, mocker): + mock_spark = mocker.MagicMock() + mock_spark.read.table.return_value = multi_partition_df + + tr = TableReader(mock_spark) + result = tr.get_latest("test_table", date_column="info_date") + + assert result.count() == 3 + assert all(row.info_date == date(2023, 1, 2) for row in result.collect()) + + +def test_get_table_with_multiple_filters(multi_partition_df, mocker): + mock_spark = mocker.MagicMock() + mock_spark.read.table.return_value = multi_partition_df + + tr = TableReader(mock_spark) + result = tr.get_table("test_table", date_column="info_date", filters={"region": "REGION_A", "type": "TYPE_X"}) + + assert result.count() == 2 + + assert result.first().info_date == date(2023, 1, 1) + assert result.first().region == "REGION_A" + assert result.first().type == "TYPE_X" + assert result.first().value == 100 + + +def test_get_table_without_filters(multi_partition_df, mocker): + mock_spark = mocker.MagicMock() + mock_spark.read.table.return_value = multi_partition_df + + tr = TableReader(mock_spark) + result = tr.get_table("test_table", date_column="info_date") + + assert result.count() == 6 diff --git a/tests/runner/runner_resources.py b/tests/runner/runner_resources.py index bd39947..17b5447 100644 --- a/tests/runner/runner_resources.py +++ b/tests/runner/runner_resources.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pyspark.sql.types import DateType, StringType, StructField, StructType +from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType from rialto.runner.date_manager import DateManager @@ -23,7 +23,6 @@ general_schema = StructType([StructField("KEY", StringType(), True), StructField("DATE", DateType(), True)]) - dep1_data = [ ("E", DateManager.str_to_date("2023-03-05")), ("F", DateManager.str_to_date("2023-03-10")), @@ -36,3 +35,20 @@ ("K", DateManager.str_to_date("2022-12-01")), ("L", DateManager.str_to_date("2023-01-01")), ] + +multi_part_data = [ + ("W", 1, "A", DateManager.str_to_date("2023-03-05")), + ("E", 1, "B", DateManager.str_to_date("2023-03-05")), + ("R", 2, "B", DateManager.str_to_date("2023-03-05")), + ("T", 1, "B", DateManager.str_to_date("2023-03-12")), + ("Y", 2, "A", DateManager.str_to_date("2023-03-19")), +] + +multi_schema = StructType( + [ + StructField("VALUE", StringType(), True), + StructField("VERSION", IntegerType(), True), + StructField("TYPE", StringType(), True), + StructField("DATE", DateType(), True), + ] +) diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 352c555..cbb4b7b 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -11,52 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime -from typing import Optional - import pytest from pyspark.sql import DataFrame import rialto.runner.utils as utils -from rialto.common.table_reader import DataReader +from rialto.common.table_reader import TableReader from rialto.runner.runner import DateManager, Runner from rialto.runner.table import Table from tests.runner.runner_resources import ( dep1_data, dep2_data, general_schema, + multi_part_data, + multi_schema, simple_group_data, ) from tests.runner.transformations.simple_group import SimpleGroup -class MockReader(DataReader): +class MockReader(TableReader): def __init__(self, spark): self.spark = spark - def get_table( - self, - table: str, - date_from: Optional[datetime.date] = None, - date_to: Optional[datetime.date] = None, - date_column: str = None, - uppercase_columns: bool = False, - ) -> DataFrame: + def _get_raw_data(self, table: str) -> DataFrame: if table == "catalog.schema.simple_group": return self.spark.createDataFrame(simple_group_data, general_schema) if table == "source.schema.dep1": return self.spark.createDataFrame(dep1_data, general_schema) if table == "source.schema.dep2": return self.spark.createDataFrame(dep2_data, general_schema) - - def get_latest( - self, - table: str, - date_until: Optional[datetime.date] = None, - date_column: str = None, - uppercase_columns: bool = False, - ) -> DataFrame: - pass + if table == "source.schema.multi_part_data": + return self.spark.createDataFrame(multi_part_data, multi_schema) def test_table_exists(spark, mocker): @@ -150,6 +135,37 @@ def test_completion_rerun(spark, mocker, basic_runner): assert comp == expected +def test_completion_secondary_partitions(spark, mocker, basic_runner): + mocker.patch("rialto.runner.utils.table_exists", return_value=True) + + basic_runner.reader = MockReader(spark) + + dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] + dates = [DateManager.str_to_date(d) for d in dates] + filters = {"version": 1, "type": "A"} + + comp = basic_runner._get_completion( + Table(table_path="source.schema.multi_part_data", partition="DATE"), dates, filters + ) + expected = [False, True, False, False, False] + assert comp == expected + + +def test_completion_secondary_partitions_no_filter(spark, mocker, basic_runner): + mocker.patch("rialto.runner.utils.table_exists", return_value=True) + + basic_runner.reader = MockReader(spark) + + dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] + dates = [DateManager.str_to_date(d) for d in dates] + + comp = basic_runner._get_completion( + Table(table_path="source.schema.multi_part_data", partition="DATE", secondary_partitions=["VERSION"]), dates + ) + expected = [False, False, False, False, False] + assert comp == expected + + def test_check_dates_have_partition(spark, mocker): mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) @@ -161,7 +177,7 @@ def test_check_dates_have_partition(spark, mocker): runner.reader = MockReader(spark) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] dates = [DateManager.str_to_date(d) for d in dates] - res = runner.check_dates_have_partition(Table(schema_path="source.schema", table="dep1", partition="DATE"), dates) + res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep1", partition="DATE"), dates) expected = [False, True, False] assert res == expected @@ -176,7 +192,7 @@ def test_check_dates_have_partition_no_table(spark, mocker): ) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] dates = [DateManager.str_to_date(d) for d in dates] - res = runner.check_dates_have_partition(Table(schema_path="source.schema", table="dep66", partition="DATE"), dates) + res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep66", partition="DATE"), dates) expected = [False, False, False] assert res == expected @@ -198,6 +214,23 @@ def test_check_dependencies(spark, mocker, r_date, expected): assert res == expected +@pytest.mark.parametrize( + "r_date, expected", + [("2023-03-19", True), ("2023-03-18", False)], +) +def test_check_dependencies_filter(spark, mocker, r_date, expected): + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) + + runner = Runner( + spark, + config_path="tests/runner/transformations/config3.yaml", + run_date="2023-03-19", + ) + runner.reader = MockReader(spark) + res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date)) + assert res == expected + + def test_check_no_dependencies(spark, mocker): mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) @@ -288,3 +321,10 @@ def test_bookkeeping_inactive(spark, mocker): runner = Runner(spark, config_path="tests/runner/transformations/config2.yaml") assert runner.config.runner.bookkeeping is None + + +def test_config_multi_partition(spark, mocker): + runner = Runner(spark, config_path="tests/runner/transformations/config3.yaml") + assert runner.config.pipelines[0].target.secondary_partition_columns == ["VERSION", "ENV"] + assert runner.config.pipelines[0].dependencies[0].filters == {"VERSION": 2, "TYPE": "A"} + assert runner.config.pipelines[0].target.rerun_filters == {"version": 3, "env": "dev"} diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py index 82e6fa6..f1e4ead 100644 --- a/tests/runner/test_table.py +++ b/tests/runner/test_table.py @@ -26,3 +26,25 @@ def test_table_path_init(): assert t.catalog == "cat" assert t.schema == "sch" assert t.table == "tab" + + +def test_table_secondary_partitions(): + t = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"]) + + assert t.get_all_partitions() == ["part", "sec1", "sec2"] + + +def test_table_get_partitions_only_main(): + t = Table(catalog="cat", schema="sch", table="tab", partition="part") + + assert t.get_all_partitions() == ["part"] + + +def test_table_prioritize_table_name(): + t = Table(catalog=None, schema=None, table="custom", schema_path="cat.sch", table_path=None, class_name="ClaSs") + + assert t.get_table_path() == "cat.sch.custom" + assert t.get_schema_path() == "cat.sch" + assert t.catalog == "cat" + assert t.schema == "sch" + assert t.table == "custom" diff --git a/tests/runner/test_writer.py b/tests/runner/test_writer.py new file mode 100644 index 0000000..5ec9ec6 --- /dev/null +++ b/tests/runner/test_writer.py @@ -0,0 +1,58 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import date + +import pytest + +from rialto.runner.writer import Writer + + +@pytest.fixture +def sample_multi_partition(spark): + df = spark.createDataFrame( + [ + ("REGION_A", 3, date(2023, 1, 1), 100), + ("REGION_A", 3, date(2023, 1, 1), 300), + ], + schema="region string, version int, info_date date, value int", + ) + return df + + +@pytest.fixture +def sample_multi_partition_non_unique(spark): + df = spark.createDataFrame( + [ + ("REGION_A", 1, date(2023, 1, 1), 100), + ("REGION_A", 2, date(2023, 1, 1), 300), + ], + schema="region string, version int, info_date date, value int", + ) + return df + + +def test_replace_condition(sample_multi_partition): + writer = Writer(spark=None) + condition = writer._get_replace_condition(sample_multi_partition, partition_cols=["region", "version", "info_date"]) + expected_condition = "region = 'REGION_A' AND version = 3 AND info_date = '2023-01-01'" + assert condition == expected_condition + + +def test_replace_condition_non_unique(sample_multi_partition_non_unique): + writer = Writer(spark=None) + with pytest.raises(ValueError): + writer._get_replace_condition( + sample_multi_partition_non_unique, partition_cols=["region", "version", "info_date"] + ) diff --git a/tests/runner/transformations/config3.yaml b/tests/runner/transformations/config3.yaml new file mode 100644 index 0000000..72af1da --- /dev/null +++ b/tests/runner/transformations/config3.yaml @@ -0,0 +1,49 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +runner: + watched_period_units: "weeks" + watched_period_value: 5 + mail: + sender: test@testing.org + smtp: server.test + to: + - developer@testing.org + subject: test report +pipelines: +- name: SimpleGroup + module: + python_module: transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + dependencies: + - table: source.schema.multi_part_data + interval: + units: "days" + value: 1 + date_col: "DATE" + filters: + VERSION: 2 + TYPE: "A" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" + secondary_partition_columns: + - "VERSION" + - "ENV" + rerun_filters: + version: 3 + env: "dev"