From c70ce433e0e807ee57696826eb7ea23c980201b7 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Wed, 29 Apr 2026 16:54:23 +0200 Subject: [PATCH 1/4] runner refactor wip --- rialto/common/table_reader.py | 9 + rialto/runner/config_loader.py | 11 +- rialto/runner/data_checker.py | 76 ++++++ rialto/runner/date_manager.py | 66 ++++-- rialto/runner/execution_planner.py | 127 ++++++++++ rialto/runner/runner.py | 356 +++++------------------------ rialto/runner/runner_old.py | 342 +++++++++++++++++++++++++++ rialto/runner/table.py | 41 +++- rialto/runner/utils.py | 33 +-- rialto/runner/writer.py | 23 +- tests/runner/test_date_manager.py | 14 +- tests/runner/test_table.py | 4 +- 12 files changed, 745 insertions(+), 357 deletions(-) create mode 100644 rialto/runner/data_checker.py create mode 100644 rialto/runner/execution_planner.py create mode 100644 rialto/runner/runner_old.py diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 228d59b..fe910c8 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -165,3 +165,12 @@ def get_table( if uppercase_columns: df = self._uppercase_column_names(df) return df + + def table_exists(self, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return self.spark.catalog.tableExists(table) diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index 7978ac5..a3ae621 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -13,7 +13,7 @@ # limitations under the License. __all__ = [ - "get_pipelines_config", + "ConfigLoader", ] from typing import Dict, List, Optional @@ -108,3 +108,12 @@ def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: return PipelinesConfig(**cfg) else: return PipelinesConfig(**raw_config) + + +class ConfigLoader: + """Loader for pipelines config""" + + @staticmethod + def load_yaml(path: str, overrides: Dict) -> PipelinesConfig: + """Load yaml config and apply overrides""" + return get_pipelines_config(path, overrides) diff --git a/rialto/runner/data_checker.py b/rialto/runner/data_checker.py new file mode 100644 index 0000000..1af1c31 --- /dev/null +++ b/rialto/runner/data_checker.py @@ -0,0 +1,76 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import date + +from loguru import logger +from pyspark.sql import DataFrame + +from rialto.common import TableReader +from rialto.runner.table import Table + + +class DataChecker: + """Checks if data for given date or date range is present in storage""" + + def __init__(self, reader: TableReader): + self.reader = reader + + def check_date(self, target: Table, partition_date: date) -> bool: + """Check if data for given date is present in target""" + return self.check_range(target, partition_date, partition_date) + + def check_range(self, target: Table, start_date: date, end_date: date) -> bool: + """Check if data for given date range is present in target""" + if self.reader.table_exists(target.get_table_path()): + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=start_date, + date_to=end_date, + filters=target.filters, + ) + data_exists = df.count() > 0 + if data_exists and target.filters is None and target.secondary_partitions is not None: + # dependencies don't have secondary partitions, this is skipped + logger.info( + f"Overwriting {target.get_table_path()} completion status for {start_date} due to presence of " + f"secondary partitions and no filters." + ) + data_exists = False + return data_exists + else: + logger.info(f"Table {target.get_table_path()} doesn't exist!") + return False + + def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int: + """Check how many records were written""" + filters = {} + if target.filters is not None: + filters = target.filters + else: + if target.secondary_partitions: + row = df.select(*target.secondary_partitions).distinct().collect()[0] + for c in target.secondary_partitions: + val = row[0][c] + filters[c] = val + + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=partition_date, + date_to=partition_date, + filters=filters, + ) + + return df.count() diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py index 1bcef7b..e3cd043 100644 --- a/rialto/runner/date_manager.py +++ b/rialto/runner/date_manager.py @@ -17,7 +17,9 @@ from datetime import date, datetime from typing import List +from config_loader import PipelinesConfig from dateutil.relativedelta import relativedelta +from loguru import logger from rialto.runner.config_loader import ScheduleConfig @@ -25,8 +27,46 @@ class DateManager: """Date generation and shifts based on configuration""" - @staticmethod - def str_to_date(str_date: str) -> date: + def __init__(self, config: PipelinesConfig, run_date: date = None): + if run_date: + run_date = DateManager.str_to_date(run_date) + else: + run_date = date.today() + + self.date_from = self.date_subtract( + run_date=run_date, + units=self.config.runner.watched_period_units, + value=self.config.runner.watched_period_value, + ) + + self.date_until = run_date + + if self.date_from > self.date_until: + raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") + logger.info(f"Running period set to: {self.date_from} - {self.date_until}") + + def get_date_from(self) -> date: + """Get starting date of the execution window""" + return self.date_from + + def get_date_until(self) -> date: + """Get ending date of the execution window""" + return self.date_until + + def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[(date, date)]: + """ + Get list of execution and partition dates for given configuration + + :return: List of tuples with execution and partition dates + """ + datepairs = [] + execution = self.execution_dates(schedule) + for ex_date in execution: + partition = self.to_partition_date(ex_date, schedule) + datepairs.append((ex_date, partition)) + return datepairs + + def str_to_date(self, str_date: str) -> date: """ Convert YYYY-MM-DD string to date @@ -35,8 +75,7 @@ def str_to_date(str_date: str) -> date: """ return datetime.strptime(str_date, "%Y-%m-%d").date() - @staticmethod - def date_subtract(run_date: date, units: str, value: int) -> date: + def date_subtract(self, run_date: date, units: str, value: int) -> date: """ Generate starting date from given date and config @@ -55,8 +94,7 @@ def date_subtract(run_date: date, units: str, value: int) -> date: return run_date - relativedelta(days=value) raise ValueError(f"Unknown time unit {units}") - @staticmethod - def all_dates(date_from: date, date_to: date) -> List[date]: + def all_dates(self, date_from: date, date_to: date) -> List[date]: """ Get list of all dates between, inclusive @@ -69,17 +107,14 @@ def all_dates(date_from: date, date_to: date) -> List[date]: return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)] - @staticmethod - def run_dates(date_from: date, date_to: date, schedule: ScheduleConfig) -> List[date]: + def execution_dates(self, schedule: ScheduleConfig) -> List[date]: """ Select dates inside given interval depending on frequency and selected day - :param date_from: interval start - :param date_to: interval end :param schedule: schedule config :return: list of dates """ - options = DateManager.all_dates(date_from, date_to) + options = self.all_dates(self.date_from, self.date_to) if schedule.frequency == "daily": return options if schedule.frequency == "weekly": @@ -88,8 +123,7 @@ def run_dates(date_from: date, date_to: date, schedule: ScheduleConfig) -> List[ return [x for x in options if x.day == schedule.day] raise ValueError(f"Unknown frequency {schedule.frequency}") - @staticmethod - def to_info_date(date: date, schedule: ScheduleConfig) -> date: + def to_partition_date(self, date: date, schedule: ScheduleConfig) -> date: """ Shift given date according to config @@ -99,9 +133,7 @@ def to_info_date(date: date, schedule: ScheduleConfig) -> date: """ if isinstance(schedule.info_date_shift, List): for shift in schedule.info_date_shift: - date = DateManager.date_subtract(date, units=shift.units, value=shift.value) + date = self.date_subtract(date, units=shift.units, value=shift.value) else: - date = DateManager.date_subtract( - date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value - ) + date = self.date_subtract(date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value) return date diff --git a/rialto/runner/execution_planner.py b/rialto/runner/execution_planner.py new file mode 100644 index 0000000..3725f2b --- /dev/null +++ b/rialto/runner/execution_planner.py @@ -0,0 +1,127 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from datetime import date +from typing import List + +from config_loader import PipelineConfig +from date_manager import DateManager +from loguru import logger + +from rialto.runner.data_checker import DataChecker +from rialto.runner.table import Table + + +@dataclass +class Dependency: + """Class representing a pipeline dependency, with associated table and date range for checking completion""" + + table: Table + date_from: date + date_until: date + complete: bool = False + + +@dataclass +class Pipeline: + """Class representing a pipeline to be executed.""" + + op: str + execution_date: date + partition_date: date + config: PipelineConfig + target: Table + dependencies: List = field(default_factory=list) + completion: bool = False + dependencies_complete: bool = False + + def check_completion(self, checker: DataChecker, rerun: bool) -> None: + """ + Check if pipeline is complete by checking if target data exists for partition date + + :param checker: DataChecker instance to use for checking data presence + :param rerun: If True, skip completion check to allow re-running of completed pipelines + + :return: None, updates self.completion attribute + """ + if not rerun: + self.completion = checker.check_date(self.target, self.partition_date) + logger.info(f"Job {self.op} completion status for partition date {self.partition_date}: {self.completion}") + + def check_dependencies_complete(self, checker, skip_dependencies: bool) -> None: + """ + Check if dependencies are complete by checking if data exists for each dependency in date range + + :param checker: DataChecker instance to use for checking data presence + :param skip_dependencies: Skip dependency checks to allow running pipelines with incomplete dependencies + + :return: None, updates self.dependencies_complete attribute + """ + if not skip_dependencies: + for dependency in self.dependencies: + dependency.complete = checker.check_range(dependency.table, dependency.date_from, dependency.date_until) + logger.info( + f"Dependency {dependency.table.get_table_path()} completion status for date range " + f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}" + ) + self.dependencies_complete = all([dependency.complete for dependency in self.dependencies]) + + +class ExecutionPlanner: + """Planner for pipeline execution, managing tasks and their dependencies""" + + def __init__(self, date_manager: DateManager): + self.date_manager = date_manager + self.tasks = [] + + def add_pipeline(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None: + """ + Add pipeline to execution plan + + :param name: Name of the pipeline + :param execution_date: Date when the pipeline is scheduled to run + :param partition_date: Date for which the pipeline is processing data + :param config: PipelineConfig object with pipeline configuration + + :return: None, adds a Pipeline object to self.tasks + """ + target = Table.from_target_config(config) + new_pipe = Pipeline( + op=name, execution_date=execution_date, partition_date=partition_date, config=config, target=target + ) + + for dependency_config in config.dependencies: + dependency_table = Table.from_dependency_config(dependency_config) + dependency_from = self.date_manager.date_subtract( + execution_date, dependency_config.interval.units, dependency_config.interval.value + ) + dependency = Dependency(table=dependency_table, date_from=dependency_from, date_until=execution_date) + new_pipe.dependencies.append(dependency) + + self.tasks.append(new_pipe) + + def __iter__(self): + """Allow iteration over tasks in execution plan""" + return iter(self.tasks) + + def log_status(self) -> None: + """Log status of all tasks in execution plan, showing completion and dependency status""" + check = "\u2714" # ✔ + cross = "\u2718" # ✘ + logger.info(f"{'Job Name':<25} {'Partition Date':<15} {'Complete':<10} {'Deps Complete':<15}") + logger.info("-" * 70) + for task in self.tasks: + complete_icon = check if task.completion else cross + deps_icon = check if task.dependencies_complete else cross + logger.info(f"{task.op:<25} {str(task.partition_date):<15} {complete_icon:^10} {deps_icon:^15}") diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 384998f..fb743c0 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -14,22 +14,21 @@ __all__ = ["Runner"] -import datetime -from datetime import date -from typing import Dict, List, Tuple +from datetime import datetime +from typing import Dict, List +from execution_planner import ExecutionPlanner from loguru import logger from pyspark.sql import DataFrame, SparkSession import rialto.runner.utils as utils from rialto.common import TableReader -from rialto.runner.config_loader import PipelineConfig, get_pipelines_config +from rialto.runner.config_loader import ConfigLoader, PipelineConfig +from rialto.runner.data_checker import DataChecker from rialto.runner.date_manager import DateManager from rialto.runner.reporting.record import Record from rialto.runner.reporting.tracker import Tracker -from rialto.runner.table import Table -from rialto.runner.transformation import Transformation -from rialto.runner.writer import Writer +from rialto.runner.writer import DatabricksWriter class Runner: @@ -47,310 +46,79 @@ def __init__( merge_schema: bool = False, ): self.spark = spark - self.config = get_pipelines_config(config_path, overrides) + self.config = ConfigLoader().load_yaml(config_path, overrides) self.reader = TableReader(spark) self.rerun = rerun - self.skip_dependencies = skip_dependencies self.op = op - self.writer = Writer(spark, merge_schema=merge_schema) + self.skip_dependencies = skip_dependencies + self.writer = DatabricksWriter(spark, merge_schema=merge_schema) + self.checker = DataChecker(self.reader) self.tracker = Tracker( mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark ) + self.date_manager = DateManager(self.config, run_date) + self.planner = ExecutionPlanner() - if run_date: - run_date = DateManager.str_to_date(run_date) - else: - run_date = date.today() - - self.date_from = DateManager.date_subtract( - run_date=run_date, - units=self.config.runner.watched_period_units, - value=self.config.runner.watched_period_value, - ) - - self.date_until = run_date - - if self.date_from > self.date_until: - raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") - logger.info(f"Running period set to: {self.date_from} - {self.date_until}") - - def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: - """ - Run the job - - :param instance: Instance of Transformation - :param run_date: date to run for - :param pipeline: pipeline configuration - :return: Dataframe - """ - metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) - - df = instance.run( - spark=self.spark, - run_date=run_date, - config=pipeline, - reader=self.reader, - metadata_manager=metadata_manager, - feature_loader=feature_loader, - ) - - return df - - def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: - """ - Check if there are records written for given date - - :param info_date: date to check - :param table: target table object - :return: number of records - """ - filters = {} - if pipeline.target.rerun_filters is not None: - filters = pipeline.target.rerun_filters + def _select_pipelines(self) -> List[PipelineConfig]: + """Select pipelines to run based on config and input parameters""" + if self.op: + selected = [p for p in self.config.pipelines if p.name == self.op] + if len(selected) < 1: + raise ValueError(f"Unknown operation selected: {self.op}") + return selected else: - if table.secondary_partitions: - row = df.select(*table.secondary_partitions).distinct().collect()[0] - for c in table.secondary_partitions: - val = row[0][c] - filters[c] = val - - df = self.reader.get_table( - table.get_table_path(), date_column=table.partition, date_from=info_date, date_to=info_date, filters=filters - ) - - return df.count() + return self.config.pipelines - def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: - """ - For given list of dates, check if there is a matching partition for each - - :param table: Table object - :param dates: list of dates to check - :return: list of bool - """ - if utils.table_exists(self.spark, table.get_table_path()): - checks = [] - for check_date in dates: - df = self.reader.get_table( - table.get_table_path(), - date_column=table.partition, - date_from=check_date, - date_to=check_date, - filters=target_filters, + def __call__(self): + """Execute pipelines""" + pipelines = self._select_pipelines() + + # Register pipelines in execution planner with their execution and partition dates + for pipeline in pipelines: + exec_date, partition_date = self.date_manager.get_execution_and_partition_dates(pipeline.schedule) + self.planner.add_pipeline( + name=pipeline.name, execution_date=exec_date, partition_date=partition_date, config=pipeline + ) + + for task in self.planner.tasks: + task.check_completion(self.checker, self.rerun) + task.check_dependencies_complete(self.checker, self.skip_dependencies) + + self.planner.log_status() + + # TODO everything bellow is just temporary + for task in self.planner.tasks: + if not task.completion and task.dependencies_complete: + logger.info(f"Running pipeline {task.op} for partition date {task.partition_date}") + job = utils.load_module(task.config.module) + metadata_manager, feature_loader = utils.init_tools(self.spark, task.config) + run_start = datetime.now() + df = job.run( + spark=self.spark, + run_date=task.execution_date, + config=task.config, + reader=self.reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, ) - data_exists = df.count() > 0 - if data_exists and target_filters is None and table.secondary_partitions is not None: - # ensure rerun if the write consideres secondary partitions but the filter doesn't - data_exists = False - checks.append(data_exists) - return checks - else: - logger.info(f"Table {table.get_table_path()} doesn't exist!") - return [False for _ in dates] - - def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: - """ - Check for all dependencies in config if they have available partitions - - :param pipeline: configuration - :param run_date: run date - :return: bool - """ - logger.info(f"{pipeline.name} checking dependencies for {run_date}") - - error = "" - - for dependency in pipeline.dependencies: - dep_from = DateManager.date_subtract(run_date, dependency.interval.units, dependency.interval.value) - logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}") - - possible_dep_dates = DateManager.all_dates(dep_from, run_date) - - logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") - - source = Table(table_path=dependency.table, partition=dependency.date_col) - if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): - logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") - else: - msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" - logger.info(msg) - error = error + msg + "\n" - - if error != "": - self.tracker.last_error = error - return False - - return True - - def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: - """ - Check if model has run for given dates + self.writer.write(df, task.partition_date, task.target) + records = self.checker.check_written(task.target, task.package, df) - :param target_path: Table object - :param info_dates: list of dates - :return: bool list - """ - if self.rerun: - return [False for _ in info_dates] - else: - return self.check_dates_have_data(target, info_dates, filters) - - def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]: - """ - Select run dates and info dates based on completion - - :param pipeline: pipeline config - :param table: table path - :return: list of run dates and list of info dates - """ - possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) - possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] - current_state = self._get_completion(table, possible_info_dates, filters) - - selection = [ - (run, info) for run, info, state in zip(possible_run_dates, possible_info_dates, current_state) if not state - ] - - if not len(selection): - logger.info(f"{pipeline.name} has no dates to run") - return [], [] - - selected_run_dates, selected_info_dates = zip(*selection) - logger.info(f"{pipeline.name} identified to run for {selected_run_dates}") - - return list(selected_run_dates), list(selected_info_dates) - - def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int: - """ - Run one pipeline for one date - - :param pipeline: pipeline cfg - :param run_date: run date - :param info_date: information date - :param target: target Table - :return: success bool - """ - if self.skip_dependencies or self.check_dependencies(pipeline, run_date): - logger.info(f"Running {pipeline.name} for {run_date}") - - feature_group = utils.load_module(pipeline.module) - df = self._execute(feature_group, run_date, pipeline) - self.writer.write(df, info_date, target) - records = self._check_written(info_date, target, df, pipeline) - logger.info(f"Generated {records} records") - if records == 0: - raise RuntimeError("No records generated") - else: - return records - return 0 - - def _run_pipeline(self, pipeline: PipelineConfig): - """ - Run single pipeline for all required dates - - :param pipeline: pipeline cfg - :return: success bool - """ - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - logger.info(f"Loaded pipeline {pipeline.name}") - - selected_run_dates, selected_info_dates = self._select_run_dates( - pipeline, target, pipeline.target.rerun_filters - ) - - # ----------- Checking dependencies available ---------- - for run_date, info_date in zip(selected_run_dates, selected_info_dates): - run_start = datetime.datetime.now() - try: - records = self._run_one_date(pipeline, run_date, info_date, target) - if records > 0: - status = "Success" - message = "" - else: - status = "Failure" - message = self.tracker.last_error self.tracker.add( Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, + job=task.op, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, records=records, - status=status, - reason=message, + status="status", + reason="message", ) ) - except Exception as error: - logger.error(f"An exception occurred in pipeline {pipeline.name}") - logger.exception(error) - self.tracker.add( - Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, - records=0, - status="Error", - reason="Exception", - exception=str(error), - ) - ) - except KeyboardInterrupt: - logger.error(f"Pipeline {pipeline.name} interrupted") - self.tracker.add( - Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, - records=0, - status="Error", - reason="Interrupted by user", - ) - ) - raise KeyboardInterrupt - def __call__(self): - """Execute pipelines""" - logger.info("Executing pipelines") - try: - if self.op: - selected = [p for p in self.config.pipelines if p.name == self.op] - if len(selected) < 1: - raise ValueError(f"Unknown operation selected: {self.op}") - self._run_pipeline(selected[0]) - else: - for pipeline in self.config.pipelines: - self._run_pipeline(pipeline) - finally: - print(self.tracker.records) - self.tracker.report_by_mail() - logger.info("Execution finished") + # 6. run the pipeline for dates with completed dependencies + # 7. write results + # 8. sumbit tracking def debug(self) -> DataFrame: """Debug mode - run only first op for one date and return the resulting dataframe""" - logger.info("Running in debug mode") - if self.op: - pipeline = [p for p in self.config.pipelines if p.name == self.op][0] - else: - pipeline = self.config.pipelines[0] - - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) - if len(selected_run_dates) > 0: - df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) - return self.writer._process(df, selected_info_dates[0], target) - else: - logger.info("No dates to run in debug mode") diff --git a/rialto/runner/runner_old.py b/rialto/runner/runner_old.py new file mode 100644 index 0000000..4265e26 --- /dev/null +++ b/rialto/runner/runner_old.py @@ -0,0 +1,342 @@ +# # Copyright 2022 ABSA Group Limited +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# +# __all__ = ["Runner"] +# +# import datetime +# from datetime import date +# from typing import Dict, List, Tuple +# +# from loguru import logger +# from pyspark.sql import DataFrame, SparkSession +# +# import rialto.runner.utils as utils +# from rialto.common import TableReader +# from rialto.runner.config_loader import ConfigLoader, PipelineConfig +# from rialto.runner.date_manager import DateManager +# from rialto.runner.reporting.record import Record +# from rialto.runner.reporting.tracker import Tracker +# from rialto.runner.table import Table +# from rialto.runner.transformation import Transformation +# from rialto.runner.writer import Writer +# +# +# class Runner: +# """A scheduler and dependency checker for feature runs""" +# +# def __init__( +# self, +# spark: SparkSession, +# config_path: str, +# run_date: str = None, +# rerun: bool = False, +# op: str = None, +# skip_dependencies: bool = False, +# overrides: Dict = None, +# merge_schema: bool = False, +# ): +# self.spark = spark +# self.config = ConfigLoader().load_yaml(config_path, overrides) +# self.reader = TableReader(spark) +# self.rerun = rerun +# self.skip_dependencies = skip_dependencies +# self.op = op +# self.writer = Writer(spark, merge_schema=merge_schema) +# self.tracker = Tracker( +# mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark +# ) +# self.date_manager = DateManager(self.config, run_date) +# +# def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: +# """ +# Run the job +# +# :param instance: Instance of Transformation +# :param run_date: date to run for +# :param pipeline: pipeline configuration +# :return: Dataframe +# """ +# metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) +# +# df = instance.run( +# spark=self.spark, +# run_date=run_date, +# config=pipeline, +# reader=self.reader, +# metadata_manager=metadata_manager, +# feature_loader=feature_loader, +# ) +# +# return df +# +# def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: +# """ +# Check if there are records written for given date +# +# :param info_date: date to check +# :param table: target table object +# :return: number of records +# """ +# filters = {} +# if pipeline.target.rerun_filters is not None: +# filters = pipeline.target.rerun_filters +# else: +# if table.secondary_partitions: +# row = df.select(*table.secondary_partitions).distinct().collect()[0] +# for c in table.secondary_partitions: +# val = row[0][c] +# filters[c] = val +# +# df = self.reader.get_table( +# table.get_table_path(), date_column=table.partition, +# date_from=info_date, date_to=info_date, filters=filters +# ) +# +# return df.count() +# +# def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: +# """ +# For given list of dates, check if there is a matching partition for each +# +# :param table: Table object +# :param dates: list of dates to check +# :return: list of bool +# """ +# if self.reader.table_exists(self.spark, table.get_table_path()): +# checks = [] +# for check_date in dates: +# df = self.reader.get_table( +# table.get_table_path(), +# date_column=table.partition, +# date_from=check_date, +# date_to=check_date, +# filters=target_filters, +# ) +# data_exists = df.count() > 0 +# if data_exists and target_filters is None and table.secondary_partitions is not None: +# # ensure rerun if the write consideres secondary partitions but the filter doesn't +# data_exists = False +# checks.append(data_exists) +# return checks +# else: +# logger.info(f"Table {table.get_table_path()} doesn't exist!") +# return [False for _ in dates] +# +# def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: +# """ +# Check for all dependencies in config if they have available partitions +# +# :param pipeline: configuration +# :param run_date: run date +# :return: bool +# """ +# logger.info(f"{pipeline.name} checking dependencies for {run_date}") +# +# error = "" +# +# for dependency in pipeline.dependencies: +# dep_from = date_manager.date_subtract(run_date, dependency.interval.units, dependency.interval.value) +# logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}") +# +# possible_dep_dates = date_manager.all_dates(dep_from, run_date) +# +# logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") +# +# source = Table(table_path=dependency.table, partition=dependency.date_col) +# if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): +# logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") +# else: +# msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" +# logger.info(msg) +# error = error + msg + "\n" +# +# if error != "": +# self.tracker.last_error = error +# return False +# +# return True +# +# def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: +# """ +# Check if model has run for given dates +# +# :param target_path: Table object +# :param info_dates: list of dates +# :return: bool list +# """ +# if self.rerun: +# return [False for _ in info_dates] +# else: +# return self.check_dates_have_data(target, info_dates, filters) +# +# def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]: +# """ +# Select run dates and info dates based on completion +# +# :param pipeline: pipeline config +# :param table: table path +# :return: list of run dates and list of info dates +# """ +# possible_run_dates = date_manager.execution_dates(pipeline.schedule) +# possible_info_dates = [DateManager.to_partition_date(x, pipeline.schedule) for x in possible_run_dates] +# current_state = self._get_completion(table, possible_info_dates, filters) +# +# selection = [ +# (run, info) for run, info, state in zip(possible_run_dates, +# possible_info_dates, current_state) if not state +# ] +# +# if not len(selection): +# logger.info(f"{pipeline.name} has no dates to run") +# return [], [] +# +# selected_run_dates, selected_info_dates = zip(*selection) +# logger.info(f"{pipeline.name} identified to run for {selected_run_dates}") +# +# return list(selected_run_dates), list(selected_info_dates) +# +# def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int: +# """ +# Run one pipeline for one date +# +# :param pipeline: pipeline cfg +# :param run_date: run date +# :param info_date: information date +# :param target: target Table +# :return: success bool +# """ +# if self.skip_dependencies or self.check_dependencies(pipeline, run_date): +# logger.info(f"Running {pipeline.name} for {run_date}") +# +# feature_group = utils.load_module(pipeline.module) +# df = self._execute(feature_group, run_date, pipeline) +# self.writer.write(df, info_date, target) +# records = self._check_written(info_date, target, df, pipeline) +# logger.info(f"Generated {records} records") +# if records == 0: +# raise RuntimeError("No records generated") +# else: +# return records +# return 0 +# +# def _run_pipeline(self, pipeline: PipelineConfig): +# """ +# Run single pipeline for all required dates +# +# :param pipeline: pipeline cfg +# :return: success bool +# """ +# target = Table( +# schema_path=pipeline.target.target_schema, +# class_name=pipeline.module.python_class, +# partition=pipeline.target.target_partition_column, +# secondary_partitions=pipeline.target.secondary_partition_columns, +# table=pipeline.target.custom_name, +# ) +# logger.info(f"Loaded pipeline {pipeline.name}") +# +# selected_run_dates, selected_info_dates = self._select_run_dates( +# pipeline, target, pipeline.target.rerun_filters +# ) +# +# # ----------- Checking dependencies available ---------- +# for run_date, info_date in zip(selected_run_dates, selected_info_dates): +# run_start = datetime.datetime.now() +# try: +# records = self._run_one_date(pipeline, run_date, info_date, target) +# if records > 0: +# status = "Success" +# message = "" +# else: +# status = "Failure" +# message = self.tracker.last_error +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=records, +# status=status, +# reason=message, +# ) +# ) +# except Exception as error: +# logger.error(f"An exception occurred in pipeline {pipeline.name}") +# logger.exception(error) +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=0, +# status="Error", +# reason="Exception", +# exception=str(error), +# ) +# ) +# except KeyboardInterrupt: +# logger.error(f"Pipeline {pipeline.name} interrupted") +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=0, +# status="Error", +# reason="Interrupted by user", +# ) +# ) +# raise KeyboardInterrupt +# +# def __call__(self): +# """Execute pipelines""" +# logger.info("Executing pipelines") +# try: +# if self.op: +# selected = [p for p in self.config.pipelines if p.name == self.op] +# if len(selected) < 1: +# raise ValueError(f"Unknown operation selected: {self.op}") +# self._run_pipeline(selected[0]) +# else: +# for pipeline in self.config.pipelines: +# self._run_pipeline(pipeline) +# finally: +# print(self.tracker.records) +# self.tracker.report_by_mail() +# logger.info("Execution finished") +# +# def debug(self) -> DataFrame: +# """Debug mode - run only first op for one date and return the resulting dataframe""" +# logger.info("Running in debug mode") +# if self.op: +# pipeline = [p for p in self.config.pipelines if p.name == self.op][0] +# else: +# pipeline = self.config.pipelines[0] +# +# target = Table( +# schema_path=pipeline.target.target_schema, +# class_name=pipeline.module.python_class, +# partition=pipeline.target.target_partition_column, +# secondary_partitions=pipeline.target.secondary_partition_columns, +# table=pipeline.target.custom_name, +# ) +# selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) +# if len(selected_run_dates) > 0: +# df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) +# return self.writer._process(df, selected_info_dates[0], target) +# else: +# logger.info("No dates to run in debug mode") diff --git a/rialto/runner/table.py b/rialto/runner/table.py index 2d44498..8096130 100644 --- a/rialto/runner/table.py +++ b/rialto/runner/table.py @@ -14,7 +14,9 @@ __all__ = ["Table"] -from typing import List +from typing import Dict, List + +from config_loader import DependencyConfig, PipelineConfig from rialto.metadata import class_to_catalog_name @@ -22,6 +24,39 @@ class Table: """Handler for databricks catalog paths""" + @classmethod + def from_target_config(cls, config: PipelineConfig) -> "Table": + """ + Create table object from pipeline config target section + + :param config: Pipeline configuration + + :return: Table object + """ + return cls( + schema_path=config.target.target_schema, + class_name=config.module.python_class, + partition=config.target.target_partition_column, + secondary_partitions=config.target.secondary_partition_columns, + table=config.target.custom_name, + filters=config.target.filters, + ) + + @classmethod + def from_dependency_config(cls, config: DependencyConfig) -> "Table": + """ + Create table object from pipeline config dependency section + + :param config: Dependency configuration + + :return: Table object + """ + return cls( + table_path=config.table, + partition=config.date_col, + filters=config.filters, + ) + def __init__( self, catalog: str = None, @@ -32,12 +67,14 @@ def __init__( class_name: str = None, partition: str = None, secondary_partitions: List[str] = None, + filters: Dict = None, ): self.catalog = catalog self.schema = schema self.table = table self.partition = partition self.secondary_partitions = secondary_partitions + self.filters = filters if schema_path: schema_path = schema_path.split(".") self.catalog = schema_path[0] @@ -58,7 +95,7 @@ def get_table_path(self) -> str: """Get full table path""" return f"{self.catalog}.{self.schema}.{self.table}" - def get_all_partitions(self) -> List[str]: + def get_all_partition_columns(self) -> List[str]: """Get list of all partitions""" if self.secondary_partitions: return [self.partition] + self.secondary_partitions diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py index 5af1723..21a231c 100644 --- a/rialto/runner/utils.py +++ b/rialto/runner/utils.py @@ -12,19 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["load_module", "table_exists", "get_partitions", "init_tools", "find_dependency"] +__all__ = ["load_module", "init_tools", "find_dependency"] -from datetime import date from importlib import import_module -from typing import List, Tuple +from typing import Tuple from pyspark.sql import SparkSession -from rialto.common import DataReader from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner.config_loader import ModuleConfig, PipelineConfig -from rialto.runner.table import Table from rialto.runner.transformation import Transformation @@ -40,32 +37,6 @@ def load_module(cfg: ModuleConfig) -> Transformation: return class_obj() -def table_exists(spark: SparkSession, table: str) -> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return spark.catalog.tableExists(table) - - -def get_partitions(reader: DataReader, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - - def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: """ Initialize metadata manager and feature loader diff --git a/rialto/runner/writer.py b/rialto/runner/writer.py index bc147fd..63d469b 100644 --- a/rialto/runner/writer.py +++ b/rialto/runner/writer.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["Writer"] +__all__ = ["DatabricksWriter", "Writer"] +from abc import ABC, abstractmethod from datetime import date from typing import List @@ -24,9 +25,25 @@ from rialto.runner.table import Table -class Writer: +class Writer(ABC): """Supporting class for runner""" + @abstractmethod + def write(self, df: DataFrame, info_date: date, table: Table) -> None: + """ + Write dataframe to storage + + :param df: dataframe to write + :param info_date: date to partition + :param table: path to write to + :return: None + """ + pass + + +class DatabricksWriter(Writer): + """Supporting class for runner, Databricks write operations""" + def __init__(self, spark: SparkSession, merge_schema=False): self.spark = spark self.merge_schema = merge_schema @@ -104,7 +121,7 @@ def write(self, df: DataFrame, info_date: date, table: Table) -> None: df = self._process(df, info_date, table) - replace_where = self._get_replace_condition(df, table.get_all_partitions()) + replace_where = self._get_replace_condition(df, table.get_all_partition_columns()) df.write.format("delta").partitionBy(table.partition).mode("overwrite").option( "mergeSchema", "true" if self.merge_schema else "false" diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py index 73b61b8..41409b0 100644 --- a/tests/runner/test_date_manager.py +++ b/tests/runner/test_date_manager.py @@ -61,7 +61,7 @@ def test_all_dates_reversed(): def test_run_dates_weekly(): cfg = ScheduleConfig(frequency="weekly", day=5) - run_dates = DateManager.run_dates( + run_dates = DateManager.execution_dates( date_from=DateManager.str_to_date("2023-02-05"), date_to=DateManager.str_to_date("2023-04-07"), schedule=cfg, @@ -85,7 +85,7 @@ def test_run_dates_weekly(): def test_run_dates_monthly(): cfg = ScheduleConfig(frequency="monthly", day=5) - run_dates = DateManager.run_dates( + run_dates = DateManager.execution_dates( date_from=DateManager.str_to_date("2022-08-05"), date_to=DateManager.str_to_date("2023-04-07"), schedule=cfg, @@ -109,7 +109,7 @@ def test_run_dates_monthly(): def test_run_dates_daily(): cfg = ScheduleConfig(frequency="daily") - run_dates = DateManager.run_dates( + run_dates = DateManager.execution_dates( date_from=DateManager.str_to_date("2023-03-28"), date_to=DateManager.str_to_date("2023-04-03"), schedule=cfg, @@ -131,7 +131,7 @@ def test_run_dates_daily(): def test_run_dates_invalid(): cfg = ScheduleConfig(frequency="random") with pytest.raises(ValueError) as exception: - DateManager.run_dates( + DateManager.execution_dates( date_from=DateManager.str_to_date("2023-03-28"), date_to=DateManager.str_to_date("2023-04-03"), schedule=cfg, @@ -146,7 +146,7 @@ def test_run_dates_invalid(): def test_to_info_date(shift, res): cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)]) base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) + info = DateManager.to_partition_date(base, cfg) assert DateManager.str_to_date(res) == info @@ -157,7 +157,7 @@ def test_to_info_date(shift, res): def test_info_date_shift_units(unit, result): cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)]) base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) + info = DateManager.to_partition_date(base, cfg) assert DateManager.str_to_date(result) == info @@ -167,5 +167,5 @@ def test_info_date_shift_combined(): info_date_shift=[IntervalConfig(units="months", value=3), IntervalConfig(units="days", value=4)], ) base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) + info = DateManager.to_partition_date(base, cfg) assert DateManager.str_to_date("2022-12-01") == info diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py index f1e4ead..d718db1 100644 --- a/tests/runner/test_table.py +++ b/tests/runner/test_table.py @@ -31,13 +31,13 @@ def test_table_path_init(): def test_table_secondary_partitions(): t = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"]) - assert t.get_all_partitions() == ["part", "sec1", "sec2"] + assert t.get_all_partition_columns() == ["part", "sec1", "sec2"] def test_table_get_partitions_only_main(): t = Table(catalog="cat", schema="sch", table="tab", partition="part") - assert t.get_all_partitions() == ["part"] + assert t.get_all_partition_columns() == ["part"] def test_table_prioritize_table_name(): From 3d79e90be8f4472e7c73769fd90cf960aac3962e Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Tue, 19 May 2026 14:17:59 +0200 Subject: [PATCH 2/4] runner refactor --- rialto/metadata/utils.py | 2 +- rialto/runner/date_manager.py | 8 +- rialto/runner/execution_planner.py | 114 ++++++++++++++++++----------- rialto/runner/executor.py | 75 +++++++++++++++++++ rialto/runner/runner.py | 85 +++++++++------------ rialto/runner/table.py | 5 +- 6 files changed, 191 insertions(+), 98 deletions(-) create mode 100644 rialto/runner/executor.py diff --git a/rialto/metadata/utils.py b/rialto/metadata/utils.py index 0cb591c..efbbf76 100644 --- a/rialto/metadata/utils.py +++ b/rialto/metadata/utils.py @@ -17,7 +17,7 @@ def class_to_catalog_name(class_name) -> str: """ - Map python class name of feature group (CammelCase) to databricks compatible format (lowercase with underscores) + Map python class name of feature group (CamelCase) to databricks compatible format (lowercase with underscores) :param class_name: Python class name :return: feature storage name diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py index e3cd043..a730766 100644 --- a/rialto/runner/date_manager.py +++ b/rialto/runner/date_manager.py @@ -17,17 +17,17 @@ from datetime import date, datetime from typing import List -from config_loader import PipelinesConfig from dateutil.relativedelta import relativedelta from loguru import logger -from rialto.runner.config_loader import ScheduleConfig +from rialto.runner.config_loader import PipelinesConfig, ScheduleConfig class DateManager: """Date generation and shifts based on configuration""" def __init__(self, config: PipelinesConfig, run_date: date = None): + self.config = config if run_date: run_date = DateManager.str_to_date(run_date) else: @@ -53,7 +53,7 @@ def get_date_until(self) -> date: """Get ending date of the execution window""" return self.date_until - def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[(date, date)]: + def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]: """ Get list of execution and partition dates for given configuration @@ -114,7 +114,7 @@ def execution_dates(self, schedule: ScheduleConfig) -> List[date]: :param schedule: schedule config :return: list of dates """ - options = self.all_dates(self.date_from, self.date_to) + options = self.all_dates(self.date_from, self.date_until) if schedule.frequency == "daily": return options if schedule.frequency == "weekly": diff --git a/rialto/runner/execution_planner.py b/rialto/runner/execution_planner.py index 3725f2b..de3c324 100644 --- a/rialto/runner/execution_planner.py +++ b/rialto/runner/execution_planner.py @@ -15,11 +15,14 @@ from datetime import date from typing import List -from config_loader import PipelineConfig -from date_manager import DateManager from loguru import logger +from pyspark.sql import DataFrame, SparkSession +import rialto.runner.utils as utils +from rialto.common import DataReader +from rialto.runner.config_loader import PipelineConfig from rialto.runner.data_checker import DataChecker +from rialto.runner.date_manager import DateManager from rialto.runner.table import Table @@ -34,7 +37,7 @@ class Dependency: @dataclass -class Pipeline: +class Task: """Class representing a pipeline to be executed.""" op: str @@ -46,48 +49,18 @@ class Pipeline: completion: bool = False dependencies_complete: bool = False - def check_completion(self, checker: DataChecker, rerun: bool) -> None: - """ - Check if pipeline is complete by checking if target data exists for partition date - - :param checker: DataChecker instance to use for checking data presence - :param rerun: If True, skip completion check to allow re-running of completed pipelines - - :return: None, updates self.completion attribute - """ - if not rerun: - self.completion = checker.check_date(self.target, self.partition_date) - logger.info(f"Job {self.op} completion status for partition date {self.partition_date}: {self.completion}") - - def check_dependencies_complete(self, checker, skip_dependencies: bool) -> None: - """ - Check if dependencies are complete by checking if data exists for each dependency in date range - - :param checker: DataChecker instance to use for checking data presence - :param skip_dependencies: Skip dependency checks to allow running pipelines with incomplete dependencies - - :return: None, updates self.dependencies_complete attribute - """ - if not skip_dependencies: - for dependency in self.dependencies: - dependency.complete = checker.check_range(dependency.table, dependency.date_from, dependency.date_until) - logger.info( - f"Dependency {dependency.table.get_table_path()} completion status for date range " - f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}" - ) - self.dependencies_complete = all([dependency.complete for dependency in self.dependencies]) - class ExecutionPlanner: """Planner for pipeline execution, managing tasks and their dependencies""" - def __init__(self, date_manager: DateManager): + def __init__(self, spark: SparkSession, date_manager: DateManager): + self.spark = spark self.date_manager = date_manager self.tasks = [] - def add_pipeline(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None: + def add_task(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None: """ - Add pipeline to execution plan + Add task to execution plan :param name: Name of the pipeline :param execution_date: Date when the pipeline is scheduled to run @@ -97,7 +70,7 @@ def add_pipeline(self, name: str, execution_date: date, partition_date: date, co :return: None, adds a Pipeline object to self.tasks """ target = Table.from_target_config(config) - new_pipe = Pipeline( + new_pipe = Task( op=name, execution_date=execution_date, partition_date=partition_date, config=config, target=target ) @@ -119,9 +92,68 @@ def log_status(self) -> None: """Log status of all tasks in execution plan, showing completion and dependency status""" check = "\u2714" # ✔ cross = "\u2718" # ✘ - logger.info(f"{'Job Name':<25} {'Partition Date':<15} {'Complete':<10} {'Deps Complete':<15}") - logger.info("-" * 70) + status = f"\n{'Job Name':<50} {'Partition Date':<15} {'Complete':<8} {'Dependencies':<12}\n" + status = status + ("-" * 70 + "\n") for task in self.tasks: complete_icon = check if task.completion else cross deps_icon = check if task.dependencies_complete else cross - logger.info(f"{task.op:<25} {str(task.partition_date):<15} {complete_icon:^10} {deps_icon:^15}") + status = status + f"{task.op:<50} {str(task.partition_date):<15} {complete_icon:^8} {deps_icon:^12}\n" + logger.info(status) + + def check_completion(self, pipeline: Task, checker: DataChecker, rerun: bool) -> None: + """ + Check if pipeline is complete by checking if target data exists for partition date + + :param pipeline: Pipeline object for which to check completion + :param checker: DataChecker instance to use for checking data presence + :param rerun: If True, skip completion check to allow re-running of completed pipelines + + :return: None, updates self.completion attribute + """ + if not rerun: + pipeline.completion = checker.check_date(pipeline.target, pipeline.partition_date) + logger.info( + f"Job {pipeline.op} completion status for partition date " + f"{pipeline.partition_date}: {pipeline.completion}" + ) + + def check_pipeline_dependencies(self, pipeline: Task, checker: DataChecker, skip_dependencies: bool) -> None: + """ + Check if dependencies are complete by checking if data exists for each dependency in date range + + :param pipeline: Pipeline object for which to check dependencies + :param checker: DataChecker instance to use for checking data presence + :param skip_dependencies: Skip dependency checks to allow running pipelines with incomplete dependencies + + :return: None, updates self.dependencies_complete attribute + """ + if not skip_dependencies: + for dependency in pipeline.dependencies: + dependency.complete = checker.check_range(dependency.table, dependency.date_from, dependency.date_until) + logger.info( + f"Dependency {dependency.table.get_table_path()} completion status for date range " + f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}" + ) + pipeline.dependencies_complete = all([dependency.complete for dependency in pipeline.dependencies]) + + def execute_pipeline(self, pipeline: Task, reader: DataReader) -> DataFrame: + """ + Execute the pipeline, assuming all dependencies are complete and pipeline is not already complete + + :param pipeline: Pipeline object to execute + :param reader: DataReader instance to use for reading data + + :return: DataFrame output from pipeline execution + """ + logger.info(f"Executing pipeline {pipeline.op} for partition date {pipeline.partition_date}") + job = utils.load_module(pipeline.config.module) + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline.config) + df = job.run( + spark=self.spark, + run_date=pipeline.execution_date, + config=pipeline.config, + reader=reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, + ) + return df diff --git a/rialto/runner/executor.py b/rialto/runner/executor.py new file mode 100644 index 0000000..7d97ea9 --- /dev/null +++ b/rialto/runner/executor.py @@ -0,0 +1,75 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime + +from loguru import logger +from pyspark.sql import SparkSession + +import rialto.runner.utils as utils +from rialto.common import DataReader +from rialto.runner.data_checker import DataChecker +from rialto.runner.execution_planner import Task +from rialto.runner.reporting.record import Record +from rialto.runner.reporting.tracker import Tracker +from rialto.runner.writer import Writer + + +class PipelineExecutor: + """Executes a single pipeline task.""" + + def __init__(self, spark: SparkSession, reader: DataReader, writer: Writer, checker: DataChecker, tracker: Tracker): + self.spark = spark + self.reader = reader + self.writer = writer + self.checker = checker + self.tracker = tracker + + def execute(self, pipeline: Task): + """ + Execute the pipeline task. + + :param pipeline: Pipeline object to execute. + :return: None + """ + logger.info(f"Executing pipeline {pipeline.op} for partition date {pipeline.partition_date}") + run_start = datetime.now() + + # Load and run the job + job = utils.load_module(pipeline.config.module) + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline.config) + df = job.run( + spark=self.spark, + run_date=pipeline.execution_date, + config=pipeline.config, + reader=self.reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, + ) + + # Write the output + self.writer.write(df, pipeline.partition_date, pipeline.target) + + # Perform checks and track results + records = self.checker.check_written(pipeline.target, pipeline.partition_date, df) + self.tracker.add( + Record( + job=pipeline.op, + target=pipeline.target.get_table_path(), + date=pipeline.partition_date, + time=datetime.now() - run_start, + records=records, + status="status", + reason="message", + ) + ) diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index fb743c0..381bfcb 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -14,19 +14,16 @@ __all__ = ["Runner"] -from datetime import datetime from typing import Dict, List -from execution_planner import ExecutionPlanner -from loguru import logger from pyspark.sql import DataFrame, SparkSession -import rialto.runner.utils as utils from rialto.common import TableReader from rialto.runner.config_loader import ConfigLoader, PipelineConfig from rialto.runner.data_checker import DataChecker from rialto.runner.date_manager import DateManager -from rialto.runner.reporting.record import Record +from rialto.runner.execution_planner import ExecutionPlanner +from rialto.runner.executor import PipelineExecutor from rialto.runner.reporting.tracker import Tracker from rialto.runner.writer import DatabricksWriter @@ -57,7 +54,14 @@ def __init__( mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark ) self.date_manager = DateManager(self.config, run_date) - self.planner = ExecutionPlanner() + self.planner = ExecutionPlanner(spark, date_manager=self.date_manager) + self.executor = PipelineExecutor( + spark=self.spark, + reader=self.reader, + writer=self.writer, + checker=self.checker, + tracker=self.tracker, + ) def _select_pipelines(self) -> List[PipelineConfig]: """Select pipelines to run based on config and input parameters""" @@ -69,56 +73,39 @@ def _select_pipelines(self) -> List[PipelineConfig]: else: return self.config.pipelines - def __call__(self): - """Execute pipelines""" - pipelines = self._select_pipelines() - - # Register pipelines in execution planner with their execution and partition dates + def _register_tasks(self, pipelines: List[PipelineConfig]) -> None: for pipeline in pipelines: - exec_date, partition_date = self.date_manager.get_execution_and_partition_dates(pipeline.schedule) - self.planner.add_pipeline( - name=pipeline.name, execution_date=exec_date, partition_date=partition_date, config=pipeline - ) + for exec_date, partition_date in self.date_manager.get_execution_and_partition_dates(pipeline.schedule): + self.planner.add_task( + name=pipeline.name, execution_date=exec_date, partition_date=partition_date, config=pipeline + ) + def _check_tasks(self) -> None: for task in self.planner.tasks: - task.check_completion(self.checker, self.rerun) - task.check_dependencies_complete(self.checker, self.skip_dependencies) - - self.planner.log_status() + self.planner.check_completion(task, self.checker, self.rerun) + self.planner.check_pipeline_dependencies(task, self.checker, self.skip_dependencies) - # TODO everything bellow is just temporary + def _run_tasks(self) -> None: for task in self.planner.tasks: if not task.completion and task.dependencies_complete: - logger.info(f"Running pipeline {task.op} for partition date {task.partition_date}") - job = utils.load_module(task.config.module) - metadata_manager, feature_loader = utils.init_tools(self.spark, task.config) - run_start = datetime.now() - df = job.run( - spark=self.spark, - run_date=task.execution_date, - config=task.config, - reader=self.reader, - metadata_manager=metadata_manager, - feature_loader=feature_loader, - ) - self.writer.write(df, task.partition_date, task.target) - records = self.checker.check_written(task.target, task.package, df) - - self.tracker.add( - Record( - job=task.op, - target=task.target.get_table_path(), - date=task.partition_date, - time=datetime.now() - run_start, - records=records, - status="status", - reason="message", - ) - ) + self.executor.execute(task) - # 6. run the pipeline for dates with completed dependencies - # 7. write results - # 8. sumbit tracking + def __call__(self): + """Execute pipelines""" + pipelines = self._select_pipelines() + + self._register_tasks(pipelines) + self._check_tasks() + self.planner.log_status() + self._run_tasks() + + def dry_run(self): + """Dry run - log status of pipelines without executing""" + pipelines = self._select_pipelines() + + self._register_tasks(pipelines) + self._check_tasks() + self.planner.log_status() def debug(self) -> DataFrame: """Debug mode - run only first op for one date and return the resulting dataframe""" diff --git a/rialto/runner/table.py b/rialto/runner/table.py index 8096130..f97d001 100644 --- a/rialto/runner/table.py +++ b/rialto/runner/table.py @@ -16,9 +16,8 @@ from typing import Dict, List -from config_loader import DependencyConfig, PipelineConfig - from rialto.metadata import class_to_catalog_name +from rialto.runner.config_loader import DependencyConfig, PipelineConfig class Table: @@ -39,7 +38,7 @@ def from_target_config(cls, config: PipelineConfig) -> "Table": partition=config.target.target_partition_column, secondary_partitions=config.target.secondary_partition_columns, table=config.target.custom_name, - filters=config.target.filters, + filters=config.target.rerun_filters, ) @classmethod From cd274ee68336fb8f46eb4b46c4a1872acdefd41f Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Tue, 19 May 2026 23:37:50 +0200 Subject: [PATCH 3/4] data checker tests --- rialto/common/table_reader.py | 10 + rialto/runner/data_checker.py | 61 +++-- rialto/runner/date_manager.py | 2 +- ..._overrides.py => test_config_overrides.py} | 0 tests/runner/test_data_checker.py | 230 ++++++++++++++++++ 5 files changed, 285 insertions(+), 18 deletions(-) rename tests/runner/{test_overrides.py => test_config_overrides.py} (100%) create mode 100644 tests/runner/test_data_checker.py diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index fe910c8..0d8645c 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -71,6 +71,16 @@ def get_table( """ raise NotImplementedError + @abc.abstractmethod + def table_exists(self, table: str) -> bool: + """ + Check table exists in storage + + :param table: full table path + :return: bool + """ + raise NotImplementedError + class TableReader(DataReader): """An implementation of data reader for databricks tables""" diff --git a/rialto/runner/data_checker.py b/rialto/runner/data_checker.py index 1af1c31..787c56a 100644 --- a/rialto/runner/data_checker.py +++ b/rialto/runner/data_checker.py @@ -11,27 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +__all__ = ["DataChecker"] + from datetime import date +from typing import Dict from loguru import logger from pyspark.sql import DataFrame -from rialto.common import TableReader +from rialto.common import DataReader from rialto.runner.table import Table class DataChecker: """Checks if data for given date or date range is present in storage""" - def __init__(self, reader: TableReader): + def __init__(self, reader: DataReader): self.reader = reader def check_date(self, target: Table, partition_date: date) -> bool: - """Check if data for given date is present in target""" + """Check if data for given date is present in target + + :param target: target Table to check + :param partition_date: Date to check + :return: True if data for given date is present, False otherwise + """ return self.check_range(target, partition_date, partition_date) def check_range(self, target: Table, start_date: date, end_date: date) -> bool: - """Check if data for given date range is present in target""" + """Check if data for given date range is present in target + + :param target: target Table to check + :param start_date: Starting date of the range to check + :param end_date: Ending date of the range to check + :return: True if data for given date range is present, False otherwise + """ if self.reader.table_exists(target.get_table_path()): df = self.reader.get_table( target.get_table_path(), @@ -41,30 +55,43 @@ def check_range(self, target: Table, start_date: date, end_date: date) -> bool: filters=target.filters, ) data_exists = df.count() > 0 - if data_exists and target.filters is None and target.secondary_partitions is not None: - # dependencies don't have secondary partitions, this is skipped - logger.info( + if ( + data_exists + and (target.filters is None or target.filters == {}) + and target.secondary_partitions is not None + ): + logger.warning( f"Overwriting {target.get_table_path()} completion status for {start_date} due to presence of " f"secondary partitions and no filters." ) data_exists = False return data_exists else: - logger.info(f"Table {target.get_table_path()} doesn't exist!") + logger.warning(f"Target table {target.get_table_path()} doesn't exist yet.") return False - def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int: - """Check how many records were written""" - filters = {} + def _get_filters(self, target: Table, df: DataFrame) -> Dict: if target.filters is not None: - filters = target.filters + return target.filters + elif target.secondary_partitions: + filters = {} + logger.info("Inferring target sub-partition values from generated data.") + row = df.select(*target.secondary_partitions).distinct().collect()[0] + for c in target.secondary_partitions: + filters[c] = row[c] + return filters else: - if target.secondary_partitions: - row = df.select(*target.secondary_partitions).distinct().collect()[0] - for c in target.secondary_partitions: - val = row[0][c] - filters[c] = val + return {} + + def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int: + """Check how many records were written + :param target: target Table to check + :param partition_date: Date to check + :param df: DataFrame that was written, used to determine filters if not provided in config + :return: Number of records for given date + """ + filters = self._get_filters(target, df) df = self.reader.get_table( target.get_table_path(), date_column=target.partition, diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py index a730766..d12c983 100644 --- a/rialto/runner/date_manager.py +++ b/rialto/runner/date_manager.py @@ -29,7 +29,7 @@ class DateManager: def __init__(self, config: PipelinesConfig, run_date: date = None): self.config = config if run_date: - run_date = DateManager.str_to_date(run_date) + run_date = self.str_to_date(run_date) else: run_date = date.today() diff --git a/tests/runner/test_overrides.py b/tests/runner/test_config_overrides.py similarity index 100% rename from tests/runner/test_overrides.py rename to tests/runner/test_config_overrides.py diff --git a/tests/runner/test_data_checker.py b/tests/runner/test_data_checker.py new file mode 100644 index 0000000..0ce98da --- /dev/null +++ b/tests/runner/test_data_checker.py @@ -0,0 +1,230 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import date +from unittest.mock import MagicMock + +import pytest +from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType + +from rialto.common import TableReader +from rialto.runner.data_checker import DataChecker +from rialto.runner.table import Table + + +@pytest.fixture(scope="module") +def simple_dataframe(spark): + df = [ + ("A", date(2023, 3, 5)), + ("B", date(2023, 3, 12)), + ("C", date(2023, 3, 19)), + ] + schema = StructType([StructField("KEY", StringType(), True), StructField("DATE", DateType(), True)]) + return spark.createDataFrame(df, schema=schema) + + +@pytest.fixture(scope="module") +def partitioned_dataframe(spark): + df = [ + ("W", 1, "A", date(2023, 3, 5)), + ("E", 1, "B", date(2023, 3, 5)), + ("R", 2, "B", date(2023, 3, 5)), + ("T", 1, "B", date(2023, 3, 12)), + ("Y", 2, "A", date(2023, 3, 19)), + ] + schema = StructType( + [ + StructField("VALUE", StringType(), True), + StructField("VERSION", IntegerType(), True), + StructField("TYPE", StringType(), True), + StructField("DATE", DateType(), True), + ] + ) + return spark.createDataFrame(df, schema=schema) + + +@pytest.fixture(scope="module") +def new_insert_partitioned_dataframe(spark): + df = [ + ("E", 1, "B", date(2023, 3, 5)), + ("T", 1, "B", date(2023, 3, 5)), + ] + schema = StructType( + [ + StructField("VALUE", StringType(), True), + StructField("VERSION", IntegerType(), True), + StructField("TYPE", StringType(), True), + StructField("DATE", DateType(), True), + ] + ) + return spark.createDataFrame(df, schema=schema) + + +@pytest.mark.parametrize( + "partition_date, expected", + [ + (date(2023, 3, 12), True), + (date(2023, 3, 10), False), + (date(2023, 3, 19), True), + (date(2023, 3, 26), False), + ], +) +def test_check_date(mocker, spark, simple_dataframe, partition_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=simple_dataframe) + + data_checker = DataChecker(TableReader(spark)) + table = Table(table_path="catalog.schema.simple_group", partition="DATE") + result = data_checker.check_date(table, partition_date) + assert result == expected + + +@pytest.mark.parametrize( + "start_date, end_date, expected", + [ + (date(2023, 3, 12), date(2023, 4, 12), True), + (date(2023, 3, 10), date(2023, 3, 11), False), + (date(2023, 3, 19), date(2023, 3, 19), True), + (date(2023, 3, 26), date(2023, 3, 29), False), + ], +) +def test_check_range(mocker, spark, simple_dataframe, start_date, end_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=simple_dataframe) + + data_checker = DataChecker(TableReader(spark)) + table = Table(table_path="catalog.schema.simple_group", partition="DATE") + result = data_checker.check_range(table, start_date, end_date) + assert result == expected + + +def test_check_range_no_table( + mocker, + spark, +): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=False) + + data_checker = DataChecker(TableReader(spark)) + table = Table(table_path="catalog.schema.simple_group", partition="DATE") + result = data_checker.check_date(table, date(2023, 3, 12)) + assert result is False + + +@pytest.mark.parametrize( + "partition_date, expected", + [ + (date(2023, 2, 26), False), + (date(2023, 3, 5), True), + (date(2023, 3, 12), False), + (date(2023, 3, 19), False), + (date(2023, 3, 26), False), + ], +) +def test_check_date_secondary_partitions_and_filters(mocker, spark, partitioned_dataframe, partition_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=partitioned_dataframe) + + data_checker = DataChecker(TableReader(spark)) + table = Table( + table_path="catalog.schema.simple_group", + partition="DATE", + secondary_partitions=["VERSION", "TYPE"], + filters={"version": 1, "type": "A"}, + ) + result = data_checker.check_date(table, partition_date) + assert result == expected + + +@pytest.mark.parametrize( + "partition_date, expected", + [ + (date(2023, 2, 26), False), + (date(2023, 3, 5), False), + (date(2023, 3, 12), False), + (date(2023, 3, 19), False), + (date(2023, 3, 26), False), + ], +) +def test_check_date_secondary_partitions_no_filters(mocker, spark, partitioned_dataframe, partition_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=partitioned_dataframe) + + data_checker = DataChecker(TableReader(spark)) + table = Table( + table_path="catalog.schema.simple_group", + partition="DATE", + secondary_partitions=["VERSION", "TYPE"], + filters=None, + ) + result = data_checker.check_date(table, partition_date) + assert result == expected + + +def test_check_written_with_no_filters_or_secondary_partitions(): + mock_reader = MagicMock() + mock_df = MagicMock() + mock_reader.get_table.return_value = mock_df + mock_df.count.return_value = 42 + + checker = DataChecker(mock_reader) + table = Table(table_path="dummy.table.path", partition="DATE") + result = checker.check_written(table, date(2023, 3, 5), MagicMock()) + assert result == 42 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={}, + ) + + +def test_check_written_with_filters(): + mock_reader = MagicMock() + mock_df = MagicMock() + mock_reader.get_table.return_value = mock_df + mock_df.count.return_value = 42 + + checker = DataChecker(mock_reader) + table = Table(table_path="dummy.table.path", partition="DATE", filters={"foo": "bar"}) + result = checker.check_written(table, date(2023, 3, 5), MagicMock()) + assert result == 42 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={"foo": "bar"}, + ) + + +def test_check_written_with_secondary_partitions(mocker, new_insert_partitioned_dataframe): + # Setup + mock_reader = MagicMock() + mock_df = MagicMock() + mock_df.count.return_value = 7 + mock_reader.get_table.return_value = mock_df + + checker = DataChecker(mock_reader) + table = Table( + table_path="dummy.table.path", partition="DATE", filters=None, secondary_partitions=["VERSION", "TYPE"] + ) + result = checker.check_written(table, date(2023, 3, 5), new_insert_partitioned_dataframe) + assert result == 7 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={"VERSION": 1, "TYPE": "B"}, + ) From 6e7502c25715715540d054affa79c3bd3322ccb9 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Wed, 20 May 2026 01:40:35 +0200 Subject: [PATCH 4/4] cleanup --- rialto/runner/config_loader.py | 18 ++++---- rialto/runner/date_manager.py | 70 +++++++++++++++++------------- rialto/runner/execution_planner.py | 33 +++----------- rialto/runner/executor.py | 34 ++++----------- rialto/runner/runner.py | 11 +++-- 5 files changed, 69 insertions(+), 97 deletions(-) diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index a3ae621..1aa5fba 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -16,9 +16,9 @@ "ConfigLoader", ] -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from rialto.common.utils import load_yaml from rialto.runner.config_overrides import override_config @@ -35,8 +35,10 @@ class IntervalConfig(BaseConfig): class ScheduleConfig(BaseConfig): frequency: str - day: Optional[int] = 0 - info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0) + day: Optional[Union[int, str]] = 0 + info_date_shift: Optional[Union[IntervalConfig, List[IntervalConfig]]] = Field( + default_factory=lambda: IntervalConfig(units="days", value=0) + ) class DependencyConfig(BaseConfig): @@ -88,16 +90,16 @@ class PipelineConfig(BaseConfig): name: str module: ModuleConfig schedule: ScheduleConfig - dependencies: Optional[List[DependencyConfig]] = [] - target: TargetConfig = None + dependencies: Optional[List[DependencyConfig]] = Field(default_factory=list) + target: Optional[TargetConfig] = None metadata_manager: Optional[MetadataManagerConfig] = None feature_loader: Optional[FeatureLoaderConfig] = None - extras: Optional[Dict] = {} + extras: Optional[Dict] = Field(default_factory=dict) class PipelinesConfig(BaseConfig): runner: RunnerConfig - pipelines: list[PipelineConfig] + pipelines: List[PipelineConfig] def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py index d12c983..394916d 100644 --- a/rialto/runner/date_manager.py +++ b/rialto/runner/date_manager.py @@ -27,14 +27,13 @@ class DateManager: """Date generation and shifts based on configuration""" def __init__(self, config: PipelinesConfig, run_date: date = None): - self.config = config if run_date: run_date = self.str_to_date(run_date) else: run_date = date.today() self.date_from = self.date_subtract( - run_date=run_date, + input_date=run_date, units=self.config.runner.watched_period_units, value=self.config.runner.watched_period_value, ) @@ -53,19 +52,7 @@ def get_date_until(self) -> date: """Get ending date of the execution window""" return self.date_until - def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]: - """ - Get list of execution and partition dates for given configuration - - :return: List of tuples with execution and partition dates - """ - datepairs = [] - execution = self.execution_dates(schedule) - for ex_date in execution: - partition = self.to_partition_date(ex_date, schedule) - datepairs.append((ex_date, partition)) - return datepairs - + @staticmethod def str_to_date(self, str_date: str) -> date: """ Convert YYYY-MM-DD string to date @@ -73,27 +60,32 @@ def str_to_date(self, str_date: str) -> date: :param str_date: string date :return: date """ - return datetime.strptime(str_date, "%Y-%m-%d").date() + try: + return datetime.strptime(str_date, "%Y-%m-%d").date() + except ValueError: + raise ValueError(f"Invalid date format: {str_date}. Expected YYYY-MM-DD.") - def date_subtract(self, run_date: date, units: str, value: int) -> date: + @staticmethod + def date_subtract(self, input_date: date, units: str, value: int) -> date: """ - Generate starting date from given date and config + Subtract given number of units from input date - :param run_date: base date + :param input_date: base date :param units: units: years, months, weeks, days :param value: number of units to subtract :return: Starting date """ if units == "years": - return run_date - relativedelta(years=value) + return input_date - relativedelta(years=value) if units == "months": - return run_date - relativedelta(months=value) + return input_date - relativedelta(months=value) if units == "weeks": - return run_date - relativedelta(weeks=value) + return input_date - relativedelta(weeks=value) if units == "days": - return run_date - relativedelta(days=value) + return input_date - relativedelta(days=value) raise ValueError(f"Unknown time unit {units}") + @staticmethod def all_dates(self, date_from: date, date_to: date) -> List[date]: """ Get list of all dates between, inclusive @@ -107,23 +99,39 @@ def all_dates(self, date_from: date, date_to: date) -> List[date]: return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)] - def execution_dates(self, schedule: ScheduleConfig) -> List[date]: + def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]: + """ + Get list of execution and partition dates for given configuration + + :return: List of tuples with execution and partition dates + """ + execution_dates = self._execution_dates(schedule) + return [(ex_date, self._to_partition_date(ex_date, schedule)) for ex_date in execution_dates] + + def _execution_dates(self, schedule: ScheduleConfig) -> List[date]: """ Select dates inside given interval depending on frequency and selected day :param schedule: schedule config - :return: list of dates + :return: List of execution dates """ options = self.all_dates(self.date_from, self.date_until) - if schedule.frequency == "daily": + frequency = schedule.frequency.lower() + if frequency == "daily": return options - if schedule.frequency == "weekly": + if frequency == "weekly": + if not (1 <= schedule.day <= 7): + raise ValueError(f"Invalid day for weekly frequency: {schedule.day}. Must be 1-7.") return [x for x in options if x.isoweekday() == schedule.day] - if schedule.frequency == "monthly": + if frequency == "monthly": + if schedule.day == "last": + return [x for x in options if (x + relativedelta(days=1)).month != x.month] + if not (1 <= schedule.day <= 31): + raise ValueError(f"Invalid day for monthly frequency: {schedule.day}. Must be 1-31 or last.") return [x for x in options if x.day == schedule.day] - raise ValueError(f"Unknown frequency {schedule.frequency}") + raise ValueError(f"Unknown frequency: {schedule.frequency}") - def to_partition_date(self, date: date, schedule: ScheduleConfig) -> date: + def _to_partition_date(self, date: date, schedule: ScheduleConfig) -> date: """ Shift given date according to config @@ -131,7 +139,7 @@ def to_partition_date(self, date: date, schedule: ScheduleConfig) -> date: :param schedule: schedule config :return: date """ - if isinstance(schedule.info_date_shift, List): + if isinstance(schedule.info_date_shift, list): for shift in schedule.info_date_shift: date = self.date_subtract(date, units=shift.units, value=shift.value) else: diff --git a/rialto/runner/execution_planner.py b/rialto/runner/execution_planner.py index de3c324..5c8de3b 100644 --- a/rialto/runner/execution_planner.py +++ b/rialto/runner/execution_planner.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass, field from datetime import date -from typing import List +from typing import Iterator, List from loguru import logger -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession -import rialto.runner.utils as utils -from rialto.common import DataReader from rialto.runner.config_loader import PipelineConfig from rialto.runner.data_checker import DataChecker from rialto.runner.date_manager import DateManager @@ -45,7 +44,7 @@ class Task: partition_date: date config: PipelineConfig target: Table - dependencies: List = field(default_factory=list) + dependencies: List[Dependency] = field(default_factory=list) completion: bool = False dependencies_complete: bool = False @@ -84,7 +83,7 @@ def add_task(self, name: str, execution_date: date, partition_date: date, config self.tasks.append(new_pipe) - def __iter__(self): + def __iter__(self) -> Iterator[Task]: """Allow iteration over tasks in execution plan""" return iter(self.tasks) @@ -135,25 +134,3 @@ def check_pipeline_dependencies(self, pipeline: Task, checker: DataChecker, skip f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}" ) pipeline.dependencies_complete = all([dependency.complete for dependency in pipeline.dependencies]) - - def execute_pipeline(self, pipeline: Task, reader: DataReader) -> DataFrame: - """ - Execute the pipeline, assuming all dependencies are complete and pipeline is not already complete - - :param pipeline: Pipeline object to execute - :param reader: DataReader instance to use for reading data - - :return: DataFrame output from pipeline execution - """ - logger.info(f"Executing pipeline {pipeline.op} for partition date {pipeline.partition_date}") - job = utils.load_module(pipeline.config.module) - metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline.config) - df = job.run( - spark=self.spark, - run_date=pipeline.execution_date, - config=pipeline.config, - reader=reader, - metadata_manager=metadata_manager, - feature_loader=feature_loader, - ) - return df diff --git a/rialto/runner/executor.py b/rialto/runner/executor.py index 7d97ea9..7fdbc5e 100644 --- a/rialto/runner/executor.py +++ b/rialto/runner/executor.py @@ -11,39 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime + +__all__ = ["PipelineExecutor"] from loguru import logger -from pyspark.sql import SparkSession +from pyspark.sql import DataFrame, SparkSession import rialto.runner.utils as utils from rialto.common import DataReader from rialto.runner.data_checker import DataChecker from rialto.runner.execution_planner import Task -from rialto.runner.reporting.record import Record from rialto.runner.reporting.tracker import Tracker -from rialto.runner.writer import Writer class PipelineExecutor: """Executes a single pipeline task.""" - def __init__(self, spark: SparkSession, reader: DataReader, writer: Writer, checker: DataChecker, tracker: Tracker): + def __init__(self, spark: SparkSession, reader: DataReader, checker: DataChecker, tracker: Tracker): self.spark = spark self.reader = reader - self.writer = writer self.checker = checker self.tracker = tracker - def execute(self, pipeline: Task): + @logger.catch + def execute(self, pipeline: Task) -> DataFrame: """ Execute the pipeline task. :param pipeline: Pipeline object to execute. - :return: None + :return: DataFrame resulting from pipeline execution. """ logger.info(f"Executing pipeline {pipeline.op} for partition date {pipeline.partition_date}") - run_start = datetime.now() # Load and run the job job = utils.load_module(pipeline.config.module) @@ -56,20 +54,4 @@ def execute(self, pipeline: Task): metadata_manager=metadata_manager, feature_loader=feature_loader, ) - - # Write the output - self.writer.write(df, pipeline.partition_date, pipeline.target) - - # Perform checks and track results - records = self.checker.check_written(pipeline.target, pipeline.partition_date, df) - self.tracker.add( - Record( - job=pipeline.op, - target=pipeline.target.get_table_path(), - date=pipeline.partition_date, - time=datetime.now() - run_start, - records=records, - status="status", - reason="message", - ) - ) + return df diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 381bfcb..604b88a 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -58,7 +58,6 @@ def __init__( self.executor = PipelineExecutor( spark=self.spark, reader=self.reader, - writer=self.writer, checker=self.checker, tracker=self.tracker, ) @@ -88,12 +87,14 @@ def _check_tasks(self) -> None: def _run_tasks(self) -> None: for task in self.planner.tasks: if not task.completion and task.dependencies_complete: - self.executor.execute(task) + # run_start = datetime.now() + df = self.executor.execute(task) + self.writer.write(df, task.partition_date, task.target) + # records = self.checker.check_written(task.target, task.partition_date, df) def __call__(self): """Execute pipelines""" pipelines = self._select_pipelines() - self._register_tasks(pipelines) self._check_tasks() self.planner.log_status() @@ -102,10 +103,12 @@ def __call__(self): def dry_run(self): """Dry run - log status of pipelines without executing""" pipelines = self._select_pipelines() - self._register_tasks(pipelines) self._check_tasks() self.planner.log_status() def debug(self) -> DataFrame: """Debug mode - run only first op for one date and return the resulting dataframe""" + pipelines = self._select_pipelines() + self._register_tasks(pipelines) + return self.executor.execute(self.planner.tasks[0])