Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions rialto/common/table_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def get_table(
"""
raise NotImplementedError

@abc.abstractmethod
def table_exists(self, table: str) -> bool:
    """
    Report whether the given table is present in storage

    This base version carries no logic and always raises; implementations
    are expected to supply the actual lookup.

    :param table: full table path
    :return: bool
    :raises NotImplementedError: always, unless overridden
    """
    raise NotImplementedError


class TableReader(DataReader):
"""An implementation of data reader for databricks tables"""
Expand Down Expand Up @@ -165,3 +175,12 @@ def get_table(
if uppercase_columns:
df = self._uppercase_column_names(df)
return df

def table_exists(self, table: str) -> bool:
    """
    Ask the spark catalog whether the given table exists

    :param table: full table path
    :return: bool, the answer of ``spark.catalog.tableExists``
    """
    catalog = self.spark.catalog
    return catalog.tableExists(table)
2 changes: 1 addition & 1 deletion rialto/metadata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def class_to_catalog_name(class_name) -> str:
"""
Map python class name of feature group (CammelCase) to databricks compatible format (lowercase with underscores)
Map python class name of feature group (CamelCase) to databricks compatible format (lowercase with underscores)

:param class_name: Python class name
:return: feature storage name
Expand Down
29 changes: 20 additions & 9 deletions rialto/runner/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

__all__ = [
"get_pipelines_config",
"ConfigLoader",
]

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from rialto.common.utils import load_yaml
from rialto.runner.config_overrides import override_config
Expand All @@ -35,8 +35,10 @@ class IntervalConfig(BaseConfig):

class ScheduleConfig(BaseConfig):
    # Name of the run frequency.
    frequency: str
    # Day selector; may be an integer or a string, defaults to 0.
    day: Optional[Union[int, str]] = 0
    # Date shift: either a single interval or a list of intervals.
    # The factory creates a fresh zero-day interval for every instance.
    info_date_shift: Optional[Union[IntervalConfig, List[IntervalConfig]]] = Field(
        default_factory=lambda: IntervalConfig(units="days", value=0)
    )


class DependencyConfig(BaseConfig):
Expand Down Expand Up @@ -88,16 +90,16 @@ class PipelineConfig(BaseConfig):
name: str
module: ModuleConfig
schedule: ScheduleConfig
dependencies: Optional[List[DependencyConfig]] = []
target: TargetConfig = None
dependencies: Optional[List[DependencyConfig]] = Field(default_factory=list)
target: Optional[TargetConfig] = None
metadata_manager: Optional[MetadataManagerConfig] = None
feature_loader: Optional[FeatureLoaderConfig] = None
extras: Optional[Dict] = {}
extras: Optional[Dict] = Field(default_factory=dict)


class PipelinesConfig(BaseConfig):
    # Runner section of the configuration.
    runner: RunnerConfig
    # One entry per configured pipeline.
    pipelines: List[PipelineConfig]


def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig:
Expand All @@ -108,3 +110,12 @@ def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig:
return PipelinesConfig(**cfg)
else:
return PipelinesConfig(**raw_config)


class ConfigLoader:
    """Loader for pipelines config"""

    @staticmethod
    def load_yaml(path: str, overrides: Dict) -> PipelinesConfig:
        """
        Build the pipelines config from a yaml file with overrides applied

        Delegates all work to ``get_pipelines_config``.

        :param path: path of the yaml config file
        :param overrides: overrides passed through to ``get_pipelines_config``
        :return: PipelinesConfig
        """
        pipelines_config = get_pipelines_config(path, overrides)
        return pipelines_config
103 changes: 103 additions & 0 deletions rialto/runner/data_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["DataChecker"]

from datetime import date
from typing import Dict

from loguru import logger
from pyspark.sql import DataFrame

from rialto.common import DataReader
from rialto.runner.table import Table


class DataChecker:
    """Checks if data for given date or date range is present in storage"""

    def __init__(self, reader: DataReader):
        """
        :param reader: reader used to test table existence and to load table data
        """
        self.reader = reader

    def check_date(self, target: Table, partition_date: date) -> bool:
        """Check if data for given date is present in target

        :param target: target Table to check
        :param partition_date: Date to check
        :return: True if data for given date is present, False otherwise
        """
        # A single date is a range whose start and end coincide.
        return self.check_range(target, partition_date, partition_date)

    def check_range(self, target: Table, start_date: date, end_date: date) -> bool:
        """Check if data for given date range is present in target

        A missing table counts as "no data". When the target declares secondary
        partitions but no filters, present data is still reported as missing,
        because the check cannot be narrowed to a specific sub-partition.

        :param target: target Table to check
        :param start_date: Starting date of the range to check
        :param end_date: Ending date of the range to check
        :return: True if data for given date range is present, False otherwise
        """
        table_path = target.get_table_path()
        if not self.reader.table_exists(table_path):
            logger.warning(f"Target table {table_path} doesn't exist yet.")
            return False

        df = self.reader.get_table(
            table_path,
            date_column=target.partition,
            date_from=start_date,
            date_to=end_date,
            filters=target.filters,
        )
        data_exists = df.count() > 0
        no_filters = target.filters is None or target.filters == {}
        if data_exists and no_filters and target.secondary_partitions is not None:
            logger.warning(
                f"Overwriting {table_path} completion status for {start_date} due to presence of "
                f"secondary partitions and no filters."
            )
            data_exists = False
        return data_exists

    def _get_filters(self, target: Table, df: DataFrame) -> Dict:
        """Resolve the filters used to count written records

        Configured, non-empty filters win. Otherwise, if the target has
        secondary partitions, their values are taken from one row of the
        written data. An empty filter dict is treated the same as no filters,
        matching ``check_range``.

        :param target: target Table
        :param df: DataFrame that was written
        :return: filters as a dict of column name to value
        """
        if target.filters:
            return target.filters
        if target.secondary_partitions:
            logger.info("Inferring target sub-partition values from generated data.")
            # NOTE(review): an empty df yields no rows and the [0] lookup raises IndexError.
            row = df.select(*target.secondary_partitions).distinct().collect()[0]
            return {column: row[column] for column in target.secondary_partitions}
        return {}

    def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int:
        """Check how many records were written

        :param target: target Table to check
        :param partition_date: Date to check
        :param df: DataFrame that was written, used to determine filters if not provided in config
        :return: Number of records for given date
        """
        filters = self._get_filters(target, df)
        written = self.reader.get_table(
            target.get_table_path(),
            date_column=target.partition,
            date_from=partition_date,
            date_to=partition_date,
            filters=filters,
        )
        return written.count()
96 changes: 68 additions & 28 deletions rialto/runner/date_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,75 @@
from typing import List

from dateutil.relativedelta import relativedelta
from loguru import logger

from rialto.runner.config_loader import ScheduleConfig
from rialto.runner.config_loader import PipelinesConfig, ScheduleConfig


class DateManager:
"""Date generation and shifts based on configuration"""

def __init__(self, config: PipelinesConfig, run_date: "str | date | None" = None):
    """
    Set up the execution window that ends on the run date

    :param config: pipelines config; ``runner.watched_period_units`` and
        ``runner.watched_period_value`` give the length of the window
    :param run_date: last day of the window, either a YYYY-MM-DD string or a
        date; today is used when omitted
    :raises ValueError: on a malformed date string or when the window start
        falls after its end
    """
    # The window computation below reads the config from the instance,
    # so it has to be stored first.
    self.config = config

    if not run_date:
        run_date = date.today()
    elif isinstance(run_date, str):
        run_date = self.str_to_date(run_date)
    # Any other value is taken as an already constructed date.

    self.date_from = self.date_subtract(
        input_date=run_date,
        units=self.config.runner.watched_period_units,
        value=self.config.runner.watched_period_value,
    )

    self.date_until = run_date

    if self.date_from > self.date_until:
        raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}")
    logger.info(f"Running period set to: {self.date_from} - {self.date_until}")

def get_date_from(self) -> date:
    """Return the first day of the execution window"""
    window_start = self.date_from
    return window_start

def get_date_until(self) -> date:
    """Return the last day of the execution window"""
    window_end = self.date_until
    return window_end

def str_to_date(self, str_date: str) -> date:
    """
    Convert YYYY-MM-DD string to date

    :param str_date: string date
    :return: date
    :raises ValueError: when the string does not match YYYY-MM-DD
    """
    try:
        return datetime.strptime(str_date, "%Y-%m-%d").date()
    except ValueError as err:
        # Chain explicitly so the parser's own message stays attached as the cause.
        raise ValueError(f"Invalid date format: {str_date}. Expected YYYY-MM-DD.") from err

def date_subtract(self, input_date: date, units: str, value: int) -> date:
    """
    Move the input date back by the given number of units

    :param input_date: base date
    :param units: units: years, months, weeks, days
    :param value: number of units to subtract
    :return: Starting date
    :raises ValueError: when units is not one of the supported names
    """
    if units not in ("years", "months", "weeks", "days"):
        raise ValueError(f"Unknown time unit {units}")
    # Each supported unit name is also the relativedelta keyword of the same name.
    return input_date - relativedelta(**{units: value})

@staticmethod
def all_dates(date_from: date, date_to: date) -> List[date]:
def all_dates(self, date_from: date, date_to: date) -> List[date]:
"""
Get list of all dates between, inclusive

Expand All @@ -69,39 +99,49 @@ def all_dates(date_from: date, date_to: date) -> List[date]:

return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)]

def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]:
    """
    Pair every execution date of the schedule with its partition date

    :param schedule: schedule config
    :return: List of tuples with execution and partition dates
    """
    pairs = []
    for ex_date in self._execution_dates(schedule):
        partition_date = self._to_partition_date(ex_date, schedule)
        pairs.append((ex_date, partition_date))
    return pairs

def _execution_dates(self, schedule: ScheduleConfig) -> List[date]:
    """
    Select dates inside the execution window depending on frequency and selected day

    :param schedule: schedule config
    :return: List of execution dates
    :raises ValueError: on an unknown frequency or a day that is not valid
        for the chosen frequency
    """
    options = self.all_dates(self.date_from, self.date_until)
    frequency = schedule.frequency.lower()
    day = schedule.day

    if frequency == "daily":
        return options

    if frequency == "weekly":
        # day may be a string or None per the schedule config; reject those before
        # the range comparison so the caller gets ValueError rather than TypeError.
        if not isinstance(day, int) or not (1 <= day <= 7):
            raise ValueError(f"Invalid day for weekly frequency: {schedule.day}. Must be 1-7.")
        return [x for x in options if x.isoweekday() == day]

    if frequency == "monthly":
        if day == "last":
            # A date is the last of its month when the following day is in another month.
            return [x for x in options if (x + relativedelta(days=1)).month != x.month]
        if not isinstance(day, int) or not (1 <= day <= 31):
            raise ValueError(f"Invalid day for monthly frequency: {schedule.day}. Must be 1-31 or last.")
        return [x for x in options if x.day == day]

    raise ValueError(f"Unknown frequency: {schedule.frequency}")

def _to_partition_date(self, date: date, schedule: ScheduleConfig) -> date:
    """
    Shift given date according to config

    The configured shift may be a single interval or a list of intervals;
    a list is applied in order. A missing (None) shift leaves the date unchanged.

    :param date: input date
    :param schedule: schedule config
    :return: date
    """
    shifts = schedule.info_date_shift
    if shifts is None:
        # The field is optional; with nothing configured there is nothing to subtract.
        return date
    if not isinstance(shifts, list):
        shifts = [shifts]
    for shift in shifts:
        date = self.date_subtract(date, units=shift.units, value=shift.value)
    return date
Loading
Loading