Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 52 additions & 50 deletions data_rentgen/consumer/extractors/batch_extraction_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Hashable
from typing import TypeVar

from data_rentgen.dto import (
DTO,
ColumnLineageDTO,
DatasetDTO,
DatasetSymlinkDTO,
DatasetSymlinkGroupDTO,
InputDTO,
JobDependencyDTO,
JobDTO,
Expand Down Expand Up @@ -47,29 +47,29 @@ class BatchExtractionResult:
"""

def __init__(self):
self._locations: dict[tuple, LocationDTO] = {}
self._datasets: dict[tuple, DatasetDTO] = {}
self._dataset_symlinks: dict[tuple, DatasetSymlinkDTO] = {}
self._job_types: dict[tuple, JobTypeDTO] = {}
self._jobs: dict[tuple, JobDTO] = {}
self._job_dependencies: dict[tuple, JobDependencyDTO] = {}
self._runs: dict[tuple, RunDTO] = {}
self._operations: dict[tuple, OperationDTO] = {}
self._inputs: dict[tuple, InputDTO] = {}
self._outputs: dict[tuple, OutputDTO] = {}
self._column_lineage: dict[tuple, ColumnLineageDTO] = {}
self._schemas: dict[tuple, SchemaDTO] = {}
self._sql_queries: dict[tuple, SQLQueryDTO] = {}
self._tags: dict[tuple, TagDTO] = {}
self._tag_values: dict[tuple, TagValueDTO] = {}
self._users: dict[tuple, UserDTO] = {}
self._locations: dict[Hashable, LocationDTO] = {}
self._datasets: dict[Hashable, DatasetDTO] = {}
self._dataset_symlink_groups: dict[Hashable, DatasetSymlinkGroupDTO] = {}
self._job_types: dict[Hashable, JobTypeDTO] = {}
self._jobs: dict[Hashable, JobDTO] = {}
self._job_dependencies: dict[Hashable, JobDependencyDTO] = {}
self._runs: dict[Hashable, RunDTO] = {}
self._operations: dict[Hashable, OperationDTO] = {}
self._inputs: dict[Hashable, InputDTO] = {}
self._outputs: dict[Hashable, OutputDTO] = {}
self._column_lineage: dict[Hashable, ColumnLineageDTO] = {}
self._schemas: dict[Hashable, SchemaDTO] = {}
self._sql_queries: dict[Hashable, SQLQueryDTO] = {}
self._tags: dict[Hashable, TagDTO] = {}
self._tag_values: dict[Hashable, TagValueDTO] = {}
self._users: dict[Hashable, UserDTO] = {}

def __repr__(self):
return (
"ExtractionResult("
f"locations={len(self._locations)}, "
f"datasets={len(self._datasets)}, "
f"dataset_symlinks={len(self._dataset_symlinks)}, "
f"dataset_symlink_groups={len(self._dataset_symlink_groups)}, "
f"job_types={len(self._job_types)}, "
f"jobs={len(self._jobs)}, "
f"job_dependencies={len(self._job_dependencies)}, "
Expand All @@ -87,7 +87,7 @@ def __repr__(self):
)

@staticmethod
def _add(context: dict[tuple, T], new_item: T) -> T:
def _add(context: dict[Hashable, T], new_item: T) -> T:
key = new_item.unique_key
old_item = context.get(key)
if not old_item:
Expand All @@ -110,10 +110,11 @@ def add_dataset(self, dataset: DatasetDTO):
dataset.tag_values = {self.add_tag_value(tag_value) for tag_value in dataset.tag_values}
return self._add(self._datasets, dataset)

def add_dataset_symlink(self, dataset_symlink: DatasetSymlinkDTO):
dataset_symlink.from_dataset = self.add_dataset(dataset_symlink.from_dataset)
dataset_symlink.to_dataset = self.add_dataset(dataset_symlink.to_dataset)
return self._add(self._dataset_symlinks, dataset_symlink)
def add_dataset_symlink_group(self, dataset_symlink_group: DatasetSymlinkGroupDTO):
dataset_symlink_group.members = [
(self.add_dataset(dataset), role) for dataset, role in dataset_symlink_group.members
]
return self._add(self._dataset_symlink_groups, dataset_symlink_group)

def add_job_type(self, job_type: JobTypeDTO):
return self._add(self._job_types, job_type)
Expand Down Expand Up @@ -186,40 +187,41 @@ def add_tag_value(self, tag_value: TagValueDTO):
def add_user(self, user: UserDTO):
return self._add(self._users, user)

def get_location(self, location_key: tuple) -> LocationDTO:
def get_location(self, location_key: Hashable) -> LocationDTO:
return self._locations[location_key]

def get_schema(self, schema_key: tuple) -> SchemaDTO:
def get_schema(self, schema_key: Hashable) -> SchemaDTO:
return self._schemas[schema_key]

def get_sql_query(self, sql_query_key: tuple) -> SQLQueryDTO:
def get_sql_query(self, sql_query_key: Hashable) -> SQLQueryDTO:
return self._sql_queries[sql_query_key]

def get_user(self, user_key: tuple) -> UserDTO:
def get_user(self, user_key: Hashable) -> UserDTO:
return self._users[user_key]

def get_tag(self, tag_key: tuple) -> TagDTO:
def get_tag(self, tag_key: Hashable) -> TagDTO:
return self._tags[tag_key]

def get_tag_value(self, tag_value_key: tuple) -> TagValueDTO:
def get_tag_value(self, tag_value_key: Hashable) -> TagValueDTO:
return self._tag_values[tag_value_key]

def get_dataset(self, dataset_key: tuple) -> DatasetDTO:
def get_dataset(self, dataset_key: Hashable) -> DatasetDTO:
dataset = self._datasets[dataset_key]
dataset.location = self.get_location(dataset.location.unique_key)
dataset.tag_values = {self.get_tag_value(tag_value.unique_key) for tag_value in dataset.tag_values}
return dataset

def get_dataset_symlink(self, dataset_symlink_key: tuple) -> DatasetSymlinkDTO:
dataset_symlink = self._dataset_symlinks[dataset_symlink_key]
dataset_symlink.from_dataset = self.get_dataset(dataset_symlink.from_dataset.unique_key)
dataset_symlink.to_dataset = self.get_dataset(dataset_symlink.to_dataset.unique_key)
return dataset_symlink
def get_dataset_symlink_group(self, dataset_symlink_group_key: Hashable) -> DatasetSymlinkGroupDTO:
dataset_symlink_group = self._dataset_symlink_groups[dataset_symlink_group_key]
dataset_symlink_group.members = [
(self.get_dataset(dataset.unique_key), role) for dataset, role in dataset_symlink_group.members
]
return dataset_symlink_group

def get_job_type(self, job_type_key: tuple) -> JobTypeDTO:
def get_job_type(self, job_type_key: Hashable) -> JobTypeDTO:
return self._job_types[job_type_key]

def get_job(self, job_key: tuple) -> JobDTO:
def get_job(self, job_key: Hashable) -> JobDTO:
job = self._jobs[job_key]
job.location = self.get_location(job.location.unique_key)
if job.type:
Expand All @@ -229,13 +231,13 @@ def get_job(self, job_key: tuple) -> JobDTO:
job.tag_values = {self.get_tag_value(tag_value.unique_key) for tag_value in job.tag_values}
return job

def get_job_dependency(self, job_dependency_key: tuple) -> JobDependencyDTO:
def get_job_dependency(self, job_dependency_key: Hashable) -> JobDependencyDTO:
job_dependency = self._job_dependencies[job_dependency_key]
job_dependency.from_job = self.get_job(job_dependency.from_job.unique_key)
job_dependency.to_job = self.get_job(job_dependency.to_job.unique_key)
return job_dependency

def get_run(self, run_key: tuple) -> RunDTO:
def get_run(self, run_key: Hashable) -> RunDTO:
run = self._runs[run_key]
run.job = self.get_job(run.job.unique_key)
if run.parent_run:
Expand All @@ -244,48 +246,48 @@ def get_run(self, run_key: tuple) -> RunDTO:
run.user = self.get_user(run.user.unique_key)
return run

def get_operation(self, operation_key: tuple) -> OperationDTO:
def get_operation(self, operation_key: Hashable) -> OperationDTO:
operation = self._operations[operation_key]
operation.run = self.get_run(operation.run.unique_key)
return operation

def get_input(self, input_key: tuple) -> InputDTO:
def get_input(self, input_key: Hashable) -> InputDTO:
input_ = self._inputs[input_key]
input_.operation = self.get_operation(input_.operation.unique_key)
input_.dataset = self.get_dataset(input_.dataset.unique_key)
if input_.schema:
input_.schema = self.get_schema(input_.schema.unique_key)
return input_

def get_output(self, output_key: tuple) -> OutputDTO:
def get_output(self, output_key: Hashable) -> OutputDTO:
output = self._outputs[output_key]
output.operation = self.get_operation(output.operation.unique_key)
output.dataset = self.get_dataset(output.dataset.unique_key)
if output.schema:
output.schema = self.get_schema(output.schema.unique_key)
return output

def get_column_lineage(self, output_key: tuple) -> ColumnLineageDTO:
def get_column_lineage(self, output_key: Hashable) -> ColumnLineageDTO:
lineage = self._column_lineage[output_key]
lineage.operation = self.get_operation(lineage.operation.unique_key)
lineage.source_dataset = self.get_dataset(lineage.source_dataset.unique_key)
lineage.target_dataset = self.get_dataset(lineage.target_dataset.unique_key)
return lineage

@staticmethod
def _resolve(getter: Callable[[tuple], T], items: dict[tuple, T]) -> list[T]:
def _resolve(getter: Callable[[Hashable], T], items: dict[Hashable, T]) -> list[T]:
resolved = list(map(getter, items))
unique = {item.unique_key: item for item in resolved}
return [unique[key] for key in sorted(unique.keys())]
return [unique[key] for key in sorted(unique.keys(), key=str)]

def locations(self) -> list[LocationDTO]:
return self._resolve(self.get_location, self._locations)

def datasets(self) -> list[DatasetDTO]:
return self._resolve(self.get_dataset, self._datasets)

def dataset_symlinks(self) -> list[DatasetSymlinkDTO]:
return self._resolve(self.get_dataset_symlink, self._dataset_symlinks)
def dataset_symlink_groups(self) -> list[DatasetSymlinkGroupDTO]:
return self._resolve(self.get_dataset_symlink_group, self._dataset_symlink_groups)

def job_types(self) -> list[JobTypeDTO]:
return self._resolve(self.get_job_type, self._job_types)
Expand Down Expand Up @@ -333,8 +335,8 @@ def merge(self, other: BatchExtractionResult) -> BatchExtractionResult: # noqa:
for dataset in other.datasets():
self.add_dataset(dataset)

for dataset_symlink in other.dataset_symlinks():
self.add_dataset_symlink(dataset_symlink)
for dataset_symlink_group in other.dataset_symlink_groups():
self.add_dataset_symlink_group(dataset_symlink_group)

for job_type in other.job_types():
self.add_job_type(job_type)
Expand Down
12 changes: 6 additions & 6 deletions data_rentgen/consumer/extractors/batch_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ def _add_operation(self, event: OpenLineageRunEvent, extractor: ExtractorInterfa
self.result.add_operation(operation)

for input_dataset in event.inputs:
input_dto, symlink_dtos = extractor.extract_input(operation, input_dataset, event)
input_dto, symlink_groups = extractor.extract_input(operation, input_dataset, event)
self.result.add_input(input_dto)

for symlink_dto in symlink_dtos:
self.result.add_dataset_symlink(symlink_dto)
for symlink_group in symlink_groups:
self.result.add_dataset_symlink_group(symlink_group)

for output_dataset in event.outputs:
output_dto, symlink_dtos = extractor.extract_output(operation, output_dataset, event)
output_dto, symlink_groups = extractor.extract_output(operation, output_dataset, event)
self.result.add_output(output_dto)

for symlink_dto in symlink_dtos:
self.result.add_dataset_symlink(symlink_dto)
for symlink_group in symlink_groups:
self.result.add_dataset_symlink_group(symlink_group)

column_lineage = extractor.extract_column_lineage(operation, output_dataset, event)
for item in column_lineage:
Expand Down
63 changes: 26 additions & 37 deletions data_rentgen/consumer/extractors/generic/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from data_rentgen.dto import (
DatasetDTO,
DatasetSymlinkDTO,
DatasetSymlinkGroupDTO,
DatasetSymlinkTypeDTO,
LocationDTO,
TagDTO,
Expand All @@ -30,6 +30,14 @@
SCHEMALESS_LOCATION_TYPES = {"clickhouse", "mysql"}


def _get_symlink_role(type_: OpenLineageSymlinkType) -> DatasetSymlinkTypeDTO:
return METASTORE if type_ == OpenLineageSymlinkType.TABLE else WAREHOUSE


def _get_opposite_dataset_role(symlink_roles: list[DatasetSymlinkTypeDTO]) -> DatasetSymlinkTypeDTO:
return WAREHOUSE if METASTORE in symlink_roles else METASTORE


class DatasetExtractorMixin:
def extract_dataset(self, dataset: OpenLineageDataset) -> DatasetDTO:
"""
Expand Down Expand Up @@ -73,53 +81,34 @@ def _extract_dataset_location(
addresses={f"{scheme}://{host}" for host in hosts},
)

def extract_dataset_and_symlinks(self, dataset: OpenLineageDataset) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]:
def extract_dataset_and_symlinks(
self,
dataset: OpenLineageDataset,
) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]:
symlink_identifiers = dataset.facets.symlinks.identifiers if dataset.facets.symlinks else []
return self._extract_dataset_and_symlinks(dataset, symlink_identifiers)

def _extract_dataset_and_symlinks(
self,
dataset: OpenLineageDataset,
symlink_identifiers: list[OpenLineageSymlinkIdentifier],
) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]:
) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]:
dataset_dto = self.extract_dataset(dataset)
symlinks = []
for symlink_identifier in symlink_identifiers:
symlink_dto = self._extract_dataset_ref(symlink_identifier)
symlinks.extend(
self._connect_dataset_with_symlinks(
dataset_dto,
symlink_dto,
symlink_identifier.type,
),
)
return dataset_dto, symlinks
symlinks = [
(self._extract_dataset_ref(symlink_identifier), symlink_identifier.type)
for symlink_identifier in symlink_identifiers
]
return dataset_dto, [self._build_dataset_symlink_group(dataset_dto, symlinks)] if symlinks else []

def _connect_dataset_with_symlinks(
def _build_dataset_symlink_group(
self,
dataset: DatasetDTO,
symlink: DatasetDTO,
type_: OpenLineageSymlinkType,
) -> list[DatasetSymlinkDTO]:
result = []
is_metastore_symlink = type_ == OpenLineageSymlinkType.TABLE

result.append(
DatasetSymlinkDTO(
from_dataset=dataset,
to_dataset=symlink,
type=METASTORE if is_metastore_symlink else WAREHOUSE,
),
)
result.append(
DatasetSymlinkDTO(
from_dataset=symlink,
to_dataset=dataset,
type=WAREHOUSE if is_metastore_symlink else METASTORE,
),
)

return sorted(result, key=lambda x: x.type)
symlinks: list[tuple[DatasetDTO, OpenLineageSymlinkType]],
) -> DatasetSymlinkGroupDTO:
symlink_members = [(symlink, _get_symlink_role(type_)) for symlink, type_ in symlinks]
dataset_role = _get_opposite_dataset_role([role for _, role in symlink_members])
members = [(dataset, dataset_role), *symlink_members]
return DatasetSymlinkGroupDTO(members=members)

def _enrich_dataset_tags(self, dataset_dto: DatasetDTO, dataset: OpenLineageDataset) -> DatasetDTO:
if not dataset.facets.tags:
Expand Down
11 changes: 7 additions & 4 deletions data_rentgen/consumer/extractors/generic/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from data_rentgen.dto import (
DatasetDTO,
DatasetSymlinkDTO,
DatasetSymlinkGroupDTO,
DatasetSymlinkTypeDTO,
InputDTO,
OperationDTO,
Expand Down Expand Up @@ -51,7 +51,10 @@ class IOExtractorMixin(ABC):
io_time_resolution = timedelta(hours=1)

@abstractmethod
def extract_dataset_and_symlinks(self, dataset: OpenLineageDataset) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]:
def extract_dataset_and_symlinks(
self,
dataset: OpenLineageDataset,
) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]:
pass

def extract_io_created_at(self, operation: OperationDTO, event: OpenLineageRunEvent) -> datetime:
Expand Down Expand Up @@ -106,7 +109,7 @@ def extract_input(
operation: OperationDTO,
dataset: OpenLineageInputDataset,
event: OpenLineageRunEvent,
) -> tuple[InputDTO, list[DatasetSymlinkDTO]]:
) -> tuple[InputDTO, list[DatasetSymlinkGroupDTO]]:
"""
Extract InputDTO with optional symlinks
"""
Expand All @@ -130,7 +133,7 @@ def extract_output(
operation: OperationDTO,
dataset: OpenLineageOutputDataset,
event: OpenLineageRunEvent,
) -> tuple[OutputDTO, list[DatasetSymlinkDTO]]:
) -> tuple[OutputDTO, list[DatasetSymlinkGroupDTO]]:
"""
Extract OutputDTO with optional symlinks
"""
Expand Down
Loading
Loading