diff --git a/data_rentgen/consumer/extractors/batch_extraction_result.py b/data_rentgen/consumer/extractors/batch_extraction_result.py index 6a1b7e99..c8ae964d 100644 --- a/data_rentgen/consumer/extractors/batch_extraction_result.py +++ b/data_rentgen/consumer/extractors/batch_extraction_result.py @@ -2,14 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import TypeVar from data_rentgen.dto import ( DTO, ColumnLineageDTO, DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, InputDTO, JobDependencyDTO, JobDTO, @@ -47,29 +47,29 @@ class BatchExtractionResult: """ def __init__(self): - self._locations: dict[tuple, LocationDTO] = {} - self._datasets: dict[tuple, DatasetDTO] = {} - self._dataset_symlinks: dict[tuple, DatasetSymlinkDTO] = {} - self._job_types: dict[tuple, JobTypeDTO] = {} - self._jobs: dict[tuple, JobDTO] = {} - self._job_dependencies: dict[tuple, JobDependencyDTO] = {} - self._runs: dict[tuple, RunDTO] = {} - self._operations: dict[tuple, OperationDTO] = {} - self._inputs: dict[tuple, InputDTO] = {} - self._outputs: dict[tuple, OutputDTO] = {} - self._column_lineage: dict[tuple, ColumnLineageDTO] = {} - self._schemas: dict[tuple, SchemaDTO] = {} - self._sql_queries: dict[tuple, SQLQueryDTO] = {} - self._tags: dict[tuple, TagDTO] = {} - self._tag_values: dict[tuple, TagValueDTO] = {} - self._users: dict[tuple, UserDTO] = {} + self._locations: dict[Hashable, LocationDTO] = {} + self._datasets: dict[Hashable, DatasetDTO] = {} + self._dataset_symlink_groups: dict[Hashable, DatasetSymlinkGroupDTO] = {} + self._job_types: dict[Hashable, JobTypeDTO] = {} + self._jobs: dict[Hashable, JobDTO] = {} + self._job_dependencies: dict[Hashable, JobDependencyDTO] = {} + self._runs: dict[Hashable, RunDTO] = {} + self._operations: dict[Hashable, OperationDTO] = {} + self._inputs: dict[Hashable, InputDTO] = {} + self._outputs: dict[Hashable, OutputDTO] = {} + self._column_lineage: dict[Hashable, ColumnLineageDTO] = {} + self._schemas: dict[Hashable, SchemaDTO] = {} + self._sql_queries: dict[Hashable, SQLQueryDTO] = {} + self._tags: dict[Hashable, TagDTO] = {} + self._tag_values: dict[Hashable, TagValueDTO] = {} + self._users: dict[Hashable, UserDTO] = {} def __repr__(self): return ( "ExtractionResult(" f"locations={len(self._locations)}, " f"datasets={len(self._datasets)}, " - f"dataset_symlinks={len(self._dataset_symlinks)}, " + f"dataset_symlink_groups={len(self._dataset_symlink_groups)}, " f"job_types={len(self._job_types)}, " f"jobs={len(self._jobs)}, " f"job_dependencies={len(self._job_dependencies)}, " @@ -87,7 +87,7 @@ def __repr__(self): ) @staticmethod - def _add(context: dict[tuple, T], new_item: T) -> T: + def _add(context: dict[Hashable, T], new_item: T) -> T: key = new_item.unique_key old_item = context.get(key) if not old_item: @@ -110,10 +110,11 @@ def add_dataset(self, dataset: DatasetDTO): dataset.tag_values = {self.add_tag_value(tag_value) for tag_value in dataset.tag_values} return self._add(self._datasets, dataset) - def add_dataset_symlink(self, dataset_symlink: DatasetSymlinkDTO): - dataset_symlink.from_dataset = self.add_dataset(dataset_symlink.from_dataset) - dataset_symlink.to_dataset = self.add_dataset(dataset_symlink.to_dataset) - return self._add(self._dataset_symlinks, dataset_symlink) + def add_dataset_symlink_group(self, dataset_symlink_group: DatasetSymlinkGroupDTO): + dataset_symlink_group.members = [ + (self.add_dataset(dataset), role) for dataset, role in dataset_symlink_group.members + ] + return self._add(self._dataset_symlink_groups, dataset_symlink_group) def add_job_type(self, job_type: JobTypeDTO): return self._add(self._job_types, job_type) @@ -186,40 +187,41 @@ def add_tag_value(self, tag_value: TagValueDTO): def add_user(self, user: UserDTO): return self._add(self._users, user) - def get_location(self, location_key: tuple) -> LocationDTO: + def get_location(self, location_key: Hashable) -> LocationDTO: return self._locations[location_key] - def get_schema(self, schema_key: tuple) -> SchemaDTO: + def get_schema(self, schema_key: Hashable) -> SchemaDTO: return self._schemas[schema_key] - def get_sql_query(self, sql_query_key: tuple) -> SQLQueryDTO: + def get_sql_query(self, sql_query_key: Hashable) -> SQLQueryDTO: return self._sql_queries[sql_query_key] - def get_user(self, user_key: tuple) -> UserDTO: + def get_user(self, user_key: Hashable) -> UserDTO: return self._users[user_key] - def get_tag(self, tag_key: tuple) -> TagDTO: + def get_tag(self, tag_key: Hashable) -> TagDTO: return self._tags[tag_key] - def get_tag_value(self, tag_value_key: tuple) -> TagValueDTO: + def get_tag_value(self, tag_value_key: Hashable) -> TagValueDTO: return self._tag_values[tag_value_key] - def get_dataset(self, dataset_key: tuple) -> DatasetDTO: + def get_dataset(self, dataset_key: Hashable) -> DatasetDTO: dataset = self._datasets[dataset_key] dataset.location = self.get_location(dataset.location.unique_key) dataset.tag_values = {self.get_tag_value(tag_value.unique_key) for tag_value in dataset.tag_values} return dataset - def get_dataset_symlink(self, dataset_symlink_key: tuple) -> DatasetSymlinkDTO: - dataset_symlink = self._dataset_symlinks[dataset_symlink_key] - dataset_symlink.from_dataset = self.get_dataset(dataset_symlink.from_dataset.unique_key) - dataset_symlink.to_dataset = self.get_dataset(dataset_symlink.to_dataset.unique_key) - return dataset_symlink + def get_dataset_symlink_group(self, dataset_symlink_group_key: Hashable) -> DatasetSymlinkGroupDTO: + dataset_symlink_group = self._dataset_symlink_groups[dataset_symlink_group_key] + dataset_symlink_group.members = [ + (self.get_dataset(dataset.unique_key), role) for dataset, role in dataset_symlink_group.members + ] + return dataset_symlink_group - def get_job_type(self, job_type_key: tuple) -> JobTypeDTO: + def get_job_type(self, job_type_key: Hashable) -> JobTypeDTO: return self._job_types[job_type_key] - def get_job(self, job_key: tuple) -> JobDTO: + def get_job(self, job_key: Hashable) -> JobDTO: job = self._jobs[job_key] job.location = self.get_location(job.location.unique_key) if job.type: @@ -229,13 +231,13 @@ def get_job(self, job_key: tuple) -> JobDTO: job.tag_values = {self.get_tag_value(tag_value.unique_key) for tag_value in job.tag_values} return job - def get_job_dependency(self, job_dependency_key: tuple) -> JobDependencyDTO: + def get_job_dependency(self, job_dependency_key: Hashable) -> JobDependencyDTO: job_dependency = self._job_dependencies[job_dependency_key] job_dependency.from_job = self.get_job(job_dependency.from_job.unique_key) job_dependency.to_job = self.get_job(job_dependency.to_job.unique_key) return job_dependency - def get_run(self, run_key: tuple) -> RunDTO: + def get_run(self, run_key: Hashable) -> RunDTO: run = self._runs[run_key] run.job = self.get_job(run.job.unique_key) if run.parent_run: @@ -244,12 +246,12 @@ def get_run(self, run_key: tuple) -> RunDTO: run.user = self.get_user(run.user.unique_key) return run - def get_operation(self, operation_key: tuple) -> OperationDTO: + def get_operation(self, operation_key: Hashable) -> OperationDTO: operation = self._operations[operation_key] operation.run = self.get_run(operation.run.unique_key) return operation - def get_input(self, input_key: tuple) -> InputDTO: + def get_input(self, input_key: Hashable) -> InputDTO: input_ = self._inputs[input_key] input_.operation = self.get_operation(input_.operation.unique_key) input_.dataset = self.get_dataset(input_.dataset.unique_key) @@ -257,7 +259,7 @@ def get_input(self, input_key: tuple) -> InputDTO: input_.schema = self.get_schema(input_.schema.unique_key) return input_ - def get_output(self, output_key: tuple) -> OutputDTO: + def get_output(self, output_key: Hashable) -> OutputDTO: output = self._outputs[output_key] output.operation = self.get_operation(output.operation.unique_key) output.dataset = self.get_dataset(output.dataset.unique_key) @@ -265,7 +267,7 @@ def get_output(self, output_key: tuple) -> OutputDTO: output.schema = self.get_schema(output.schema.unique_key) return output - def get_column_lineage(self, output_key: tuple) -> ColumnLineageDTO: + def get_column_lineage(self, output_key: Hashable) -> ColumnLineageDTO: lineage = self._column_lineage[output_key] lineage.operation = self.get_operation(lineage.operation.unique_key) lineage.source_dataset = self.get_dataset(lineage.source_dataset.unique_key) @@ -273,10 +275,10 @@ def get_column_lineage(self, output_key: tuple) -> ColumnLineageDTO: return lineage @staticmethod - def _resolve(getter: Callable[[tuple], T], items: dict[tuple, T]) -> list[T]: + def _resolve(getter: Callable[[Hashable], T], items: dict[Hashable, T]) -> list[T]: resolved = list(map(getter, items)) unique = {item.unique_key: item for item in resolved} - return [unique[key] for key in sorted(unique.keys())] + return [unique[key] for key in sorted(unique.keys(), key=str)] def locations(self) -> list[LocationDTO]: return self._resolve(self.get_location, self._locations) @@ -284,8 +286,8 @@ def locations(self) -> list[LocationDTO]: def datasets(self) -> list[DatasetDTO]: return self._resolve(self.get_dataset, self._datasets) - def dataset_symlinks(self) -> list[DatasetSymlinkDTO]: - return self._resolve(self.get_dataset_symlink, self._dataset_symlinks) + def dataset_symlink_groups(self) -> list[DatasetSymlinkGroupDTO]: + return self._resolve(self.get_dataset_symlink_group, self._dataset_symlink_groups) def job_types(self) -> list[JobTypeDTO]: return self._resolve(self.get_job_type, self._job_types) @@ -333,8 +335,8 @@ def merge(self, other: BatchExtractionResult) -> BatchExtractionResult: # noqa: for dataset in other.datasets(): self.add_dataset(dataset) - for dataset_symlink in other.dataset_symlinks(): - self.add_dataset_symlink(dataset_symlink) + for dataset_symlink_group in other.dataset_symlink_groups(): + self.add_dataset_symlink_group(dataset_symlink_group) for job_type in other.job_types(): self.add_job_type(job_type) diff --git a/data_rentgen/consumer/extractors/batch_extractor.py b/data_rentgen/consumer/extractors/batch_extractor.py index 1649d004..306a15cf 100644 --- a/data_rentgen/consumer/extractors/batch_extractor.py +++ b/data_rentgen/consumer/extractors/batch_extractor.py @@ -60,18 +60,18 @@ def _add_operation(self, event: OpenLineageRunEvent, extractor: ExtractorInterfa self.result.add_operation(operation) for input_dataset in event.inputs: - input_dto, symlink_dtos = extractor.extract_input(operation, input_dataset, event) + input_dto, symlink_groups = extractor.extract_input(operation, input_dataset, event) self.result.add_input(input_dto) - for symlink_dto in symlink_dtos: - self.result.add_dataset_symlink(symlink_dto) + for symlink_group in symlink_groups: + self.result.add_dataset_symlink_group(symlink_group) for output_dataset in event.outputs: - output_dto, symlink_dtos = extractor.extract_output(operation, output_dataset, event) + output_dto, symlink_groups = extractor.extract_output(operation, output_dataset, event) self.result.add_output(output_dto) - for symlink_dto in symlink_dtos: - self.result.add_dataset_symlink(symlink_dto) + for symlink_group in symlink_groups: + self.result.add_dataset_symlink_group(symlink_group) column_lineage = extractor.extract_column_lineage(operation, output_dataset, event) for item in column_lineage: diff --git a/data_rentgen/consumer/extractors/generic/dataset.py b/data_rentgen/consumer/extractors/generic/dataset.py index 57ee962e..27d91898 100644 --- a/data_rentgen/consumer/extractors/generic/dataset.py +++ b/data_rentgen/consumer/extractors/generic/dataset.py @@ -6,7 +6,7 @@ from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, LocationDTO, TagDTO, @@ -30,6 +30,14 @@ SCHEMALESS_LOCATION_TYPES = {"clickhouse", "mysql"} +def _get_symlink_role(type_: OpenLineageSymlinkType) -> DatasetSymlinkTypeDTO: + return METASTORE if type_ == OpenLineageSymlinkType.TABLE else WAREHOUSE + + +def _get_opposite_dataset_role(symlink_roles: list[DatasetSymlinkTypeDTO]) -> DatasetSymlinkTypeDTO: + return WAREHOUSE if METASTORE in symlink_roles else METASTORE + + class DatasetExtractorMixin: def extract_dataset(self, dataset: OpenLineageDataset) -> DatasetDTO: """ @@ -73,7 +81,10 @@ def _extract_dataset_location( addresses={f"{scheme}://{host}" for host in hosts}, ) - def extract_dataset_and_symlinks(self, dataset: OpenLineageDataset) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]: + def extract_dataset_and_symlinks( + self, + dataset: OpenLineageDataset, + ) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]: symlink_identifiers = dataset.facets.symlinks.identifiers if dataset.facets.symlinks else [] return self._extract_dataset_and_symlinks(dataset, symlink_identifiers) @@ -81,45 +92,23 @@ def _extract_dataset_and_symlinks( self, dataset: OpenLineageDataset, symlink_identifiers: list[OpenLineageSymlinkIdentifier], - ) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]: + ) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]: dataset_dto = self.extract_dataset(dataset) - symlinks = [] - for symlink_identifier in symlink_identifiers: - symlink_dto = self._extract_dataset_ref(symlink_identifier) - symlinks.extend( - self._connect_dataset_with_symlinks( - dataset_dto, - symlink_dto, - symlink_identifier.type, - ), - ) - return dataset_dto, symlinks + symlinks = [ + (self._extract_dataset_ref(symlink_identifier), symlink_identifier.type) + for symlink_identifier in symlink_identifiers + ] + return dataset_dto, [self._build_dataset_symlink_group(dataset_dto, symlinks)] if symlinks else [] - def _connect_dataset_with_symlinks( + def _build_dataset_symlink_group( self, dataset: DatasetDTO, - symlink: DatasetDTO, - type_: OpenLineageSymlinkType, - ) -> list[DatasetSymlinkDTO]: - result = [] - is_metastore_symlink = type_ == OpenLineageSymlinkType.TABLE - - result.append( - DatasetSymlinkDTO( - from_dataset=dataset, - to_dataset=symlink, - type=METASTORE if is_metastore_symlink else WAREHOUSE, - ), - ) - result.append( - DatasetSymlinkDTO( - from_dataset=symlink, - to_dataset=dataset, - type=WAREHOUSE if is_metastore_symlink else METASTORE, - ), - ) - - return sorted(result, key=lambda x: x.type) + symlinks: list[tuple[DatasetDTO, OpenLineageSymlinkType]], + ) -> DatasetSymlinkGroupDTO: + symlink_members = [(symlink, _get_symlink_role(type_)) for symlink, type_ in symlinks] + dataset_role = _get_opposite_dataset_role([role for _, role in symlink_members]) + members = [(dataset, dataset_role), *symlink_members] + return DatasetSymlinkGroupDTO(members=members) def _enrich_dataset_tags(self, dataset_dto: DatasetDTO, dataset: OpenLineageDataset) -> DatasetDTO: if not dataset.facets.tags: diff --git a/data_rentgen/consumer/extractors/generic/io.py b/data_rentgen/consumer/extractors/generic/io.py index c1c78d92..fe503b32 100644 --- a/data_rentgen/consumer/extractors/generic/io.py +++ b/data_rentgen/consumer/extractors/generic/io.py @@ -8,7 +8,7 @@ from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, InputDTO, OperationDTO, @@ -51,7 +51,10 @@ class IOExtractorMixin(ABC): io_time_resolution = timedelta(hours=1) @abstractmethod - def extract_dataset_and_symlinks(self, dataset: OpenLineageDataset) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]: + def extract_dataset_and_symlinks( + self, + dataset: OpenLineageDataset, + ) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]: pass def extract_io_created_at(self, operation: OperationDTO, event: OpenLineageRunEvent) -> datetime: @@ -106,7 +109,7 @@ def extract_input( operation: OperationDTO, dataset: OpenLineageInputDataset, event: OpenLineageRunEvent, - ) -> tuple[InputDTO, list[DatasetSymlinkDTO]]: + ) -> tuple[InputDTO, list[DatasetSymlinkGroupDTO]]: """ Extract InputDTO with optional symlinks """ @@ -130,7 +133,7 @@ def extract_output( operation: OperationDTO, dataset: OpenLineageOutputDataset, event: OpenLineageRunEvent, - ) -> tuple[OutputDTO, list[DatasetSymlinkDTO]]: + ) -> tuple[OutputDTO, list[DatasetSymlinkGroupDTO]]: """ Extract OutputDTO with optional symlinks """ diff --git a/data_rentgen/consumer/extractors/impl/flink.py b/data_rentgen/consumer/extractors/impl/flink.py index c56efadc..0ecc34a6 100644 --- a/data_rentgen/consumer/extractors/impl/flink.py +++ b/data_rentgen/consumer/extractors/impl/flink.py @@ -3,7 +3,7 @@ from __future__ import annotations from data_rentgen.consumer.extractors.generic import GenericExtractor -from data_rentgen.dto import DatasetDTO, DatasetSymlinkDTO, OperationDTO, OutputTypeDTO, RunDTO +from data_rentgen.dto import DatasetDTO, DatasetSymlinkGroupDTO, OperationDTO, OutputTypeDTO, RunDTO from data_rentgen.openlineage.dataset import OpenLineageDataset, OpenLineageOutputDataset from data_rentgen.openlineage.dataset_facets import ( OpenLineageSymlinkIdentifier, @@ -50,7 +50,7 @@ def _extract_dataset_and_symlinks( self, dataset: OpenLineageDataset, symlink_identifiers: list[OpenLineageSymlinkIdentifier], - ) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]: + ) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]: # Exclude Kafka fake symlinks produced by Flink 2.x integration. # See https://github.com/OpenLineage/OpenLineage/pull/3657 symlink_identifiers = [ diff --git a/data_rentgen/consumer/extractors/impl/interface.py b/data_rentgen/consumer/extractors/impl/interface.py index 782c33d2..c9a5c06f 100644 --- a/data_rentgen/consumer/extractors/impl/interface.py +++ b/data_rentgen/consumer/extractors/impl/interface.py @@ -4,7 +4,7 @@ from typing import Protocol -from data_rentgen.dto import ColumnLineageDTO, DatasetSymlinkDTO, InputDTO, OperationDTO, OutputDTO, RunDTO +from data_rentgen.dto import ColumnLineageDTO, DatasetSymlinkGroupDTO, InputDTO, OperationDTO, OutputDTO, RunDTO from data_rentgen.openlineage.dataset import ( OpenLineageInputDataset, OpenLineageOutputDataset, @@ -43,14 +43,14 @@ def extract_input( operation: OperationDTO, dataset: OpenLineageInputDataset, event: OpenLineageRunEvent, - ) -> tuple[InputDTO, list[DatasetSymlinkDTO]]: ... + ) -> tuple[InputDTO, list[DatasetSymlinkGroupDTO]]: ... def extract_output( self, operation: OperationDTO, dataset: OpenLineageOutputDataset, event: OpenLineageRunEvent, - ) -> tuple[OutputDTO, list[DatasetSymlinkDTO]]: ... + ) -> tuple[OutputDTO, list[DatasetSymlinkGroupDTO]]: ... def extract_column_lineage( self, diff --git a/data_rentgen/consumer/extractors/impl/spark.py b/data_rentgen/consumer/extractors/impl/spark.py index 60beaa11..219eb354 100644 --- a/data_rentgen/consumer/extractors/impl/spark.py +++ b/data_rentgen/consumer/extractors/impl/spark.py @@ -6,7 +6,7 @@ import re from data_rentgen.consumer.extractors.generic import GenericExtractor -from data_rentgen.dto import DatasetDTO, DatasetSymlinkDTO, OperationDTO, RunDTO, UserDTO +from data_rentgen.dto import DatasetDTO, DatasetSymlinkGroupDTO, OperationDTO, RunDTO, UserDTO from data_rentgen.openlineage.dataset import OpenLineageDataset from data_rentgen.openlineage.dataset_facets import ( OpenLineageColumnLineageDatasetFacetFieldRef, @@ -93,7 +93,7 @@ def _extract_dataset_and_symlinks( self, dataset: OpenLineageDataset, symlink_identifiers: list[OpenLineageSymlinkIdentifier], - ) -> tuple[DatasetDTO, list[DatasetSymlinkDTO]]: + ) -> tuple[DatasetDTO, list[DatasetSymlinkGroupDTO]]: table_symlinks = [ identifier for identifier in symlink_identifiers if identifier.type == OpenLineageSymlinkType.TABLE ] @@ -121,9 +121,10 @@ def _extract_dataset_and_symlinks( return ( table_dataset_dto, - self._connect_dataset_with_symlinks( - location_dataset_dto, - table_dataset_dto, - OpenLineageSymlinkType.TABLE, - ), + [ + self._build_dataset_symlink_group( + location_dataset_dto, + [(table_dataset_dto, OpenLineageSymlinkType.TABLE)], + ), + ], ) diff --git a/data_rentgen/consumer/saver.py b/data_rentgen/consumer/saver.py index de09c65d..8f8bd276 100644 --- a/data_rentgen/consumer/saver.py +++ b/data_rentgen/consumer/saver.py @@ -29,7 +29,7 @@ async def save(self, data: BatchExtractionResult): await self.create_tags(data) await self.create_tag_values(data) await self.create_datasets(data) - await self.create_dataset_symlinks(data) + await self.create_dataset_symlink_groups(data) await self.create_job_types(data) await self.create_jobs(data) await self.create_job_dependencies(data) @@ -72,14 +72,10 @@ async def create_datasets(self, data: BatchExtractionResult): dataset = await self.unit_of_work.dataset.update(dataset, dataset_dto) # noqa: PLW2901 dataset_dto.id = dataset.id - async def create_dataset_symlinks(self, data: BatchExtractionResult): + async def create_dataset_symlink_groups(self, data: BatchExtractionResult): self.logger.debug("Creating dataset symlinks") - dataset_symlinks_pairs = await self.unit_of_work.dataset_symlink.fetch_bulk(data.dataset_symlinks()) - for dataset_symlink_dto, dataset_symlink in dataset_symlinks_pairs: - if not dataset_symlink: - async with self.unit_of_work: - dataset_symlink = await self.unit_of_work.dataset_symlink.create(dataset_symlink_dto) # noqa: PLW2901 - dataset_symlink_dto.id = dataset_symlink.id + async with self.unit_of_work: + await self.unit_of_work.dataset_symlink.create_bulk(data.dataset_symlink_groups()) async def create_job_types(self, data: BatchExtractionResult): self.logger.debug("Creating job types") diff --git a/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py b/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py new file mode 100644 index 00000000..4a5d8d19 --- /dev/null +++ b/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: 2024-present MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +"""Create dataset_symlink_group + +Revision ID: 4a02d2d5c8b1 +Revises: c3f8a2e1d749 +Create Date: 2026-06-29 19:40:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +from data_rentgen.dto import ( + DatasetDTO, + DatasetSymlinkTypeDTO, + LocationDTO, + compute_symlink_fingerprint, +) + +BACKFILL_BATCH_SIZE = 10_000 + +# revision identifiers, used by Alembic. +revision = "4a02d2d5c8b1" +down_revision = "c3f8a2e1d749" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "dataset_symlink_group", + sa.Column("dataset_id", sa.BigInteger(), autoincrement=False, nullable=False), + sa.Column("fingerprint", sa.UUID(), nullable=False), + sa.Column("type", sa.String(length=32), nullable=False), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["dataset.id"], + name=op.f("fk__dataset_symlink_group__dataset_id__dataset"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint( + "dataset_id", + "fingerprint", + name=op.f("pk__dataset_symlink_group"), + ), + ) + _backfill_symlink_groups() + op.drop_table("dataset_symlink") + op.execute( + sa.text( + """ + CREATE VIEW dataset_symlink AS + SELECT + a.dataset_id AS from_dataset_id, + b.dataset_id AS to_dataset_id, + b.type AS type + FROM dataset_symlink_group a + JOIN dataset_symlink_group b + ON a.fingerprint = b.fingerprint + AND a.dataset_id <> b.dataset_id + """, + ), + ) + + +def downgrade() -> None: + op.execute(sa.text("ALTER VIEW dataset_symlink RENAME TO dataset_symlink_view")) + op.create_table( + "dataset_symlink", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("from_dataset_id", sa.BigInteger(), nullable=False), + sa.Column("to_dataset_id", sa.BigInteger(), nullable=False), + sa.Column("type", sa.String(length=32), nullable=False), + sa.ForeignKeyConstraint( + ["from_dataset_id"], + ["dataset.id"], + name=op.f("fk__dataset_symlink__from_dataset_id__dataset"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["to_dataset_id"], + ["dataset.id"], + name=op.f("fk__dataset_symlink__to_dataset_id__dataset"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk__dataset_symlink")), + sa.UniqueConstraint( + "from_dataset_id", + "to_dataset_id", + name=op.f("uq__dataset_symlink__from_dataset_id_to_dataset_id"), + ), + ) + op.execute( + sa.text( + """ + INSERT INTO dataset_symlink (from_dataset_id, to_dataset_id, type) + SELECT from_dataset_id, to_dataset_id, type + FROM dataset_symlink_view + ON CONFLICT (from_dataset_id, to_dataset_id) DO NOTHING + """, + ), + ) + op.create_index(op.f("ix__dataset_symlink__from_dataset_id"), "dataset_symlink", ["from_dataset_id"]) + op.create_index(op.f("ix__dataset_symlink__to_dataset_id"), "dataset_symlink", ["to_dataset_id"]) + + op.execute(sa.text("DROP VIEW dataset_symlink_view")) + op.drop_table("dataset_symlink_group") + + +def _opposite_type(type_: DatasetSymlinkTypeDTO) -> DatasetSymlinkTypeDTO: + if type_ == DatasetSymlinkTypeDTO.METASTORE: + return DatasetSymlinkTypeDTO.WAREHOUSE + return DatasetSymlinkTypeDTO.METASTORE + + +def _backfill_symlink_groups() -> None: + bind = op.get_bind() + last_id = 0 + + query = sa.text( + """ + SELECT + ds.id AS id, + ds.from_dataset_id AS from_dataset_id, + ds.to_dataset_id AS to_dataset_id, + ds.type AS type, + from_location.type AS from_location_type, + from_location.name AS from_location_name, + from_dataset.name AS from_dataset_name, + to_location.type AS to_location_type, + to_location.name AS to_location_name, + to_dataset.name AS to_dataset_name + FROM dataset_symlink AS ds + JOIN dataset AS from_dataset ON from_dataset.id = ds.from_dataset_id + JOIN location AS from_location ON from_location.id = from_dataset.location_id + JOIN dataset AS to_dataset ON to_dataset.id = ds.to_dataset_id + JOIN location AS to_location ON to_location.id = to_dataset.location_id + WHERE ds.id > :last_id + ORDER BY ds.id + LIMIT :limit + """, + ) + + insert_query = sa.text( + """ + INSERT INTO dataset_symlink_group (fingerprint, dataset_id, type) + VALUES (:fingerprint, :dataset_id, :type) + ON CONFLICT (dataset_id, fingerprint) DO NOTHING + """, + ) + + while rows := bind.execute(query, {"last_id": last_id, "limit": BACKFILL_BATCH_SIZE}).fetchall(): + _insert_symlink_groups_batch(bind, insert_query, rows) + last_id = rows[-1].id + + +def _insert_symlink_groups_batch(bind, insert_query, rows) -> None: + params: list[dict] = [] + + for row in rows: + to_role = DatasetSymlinkTypeDTO(row.type) + from_role = _opposite_type(to_role) + + from_dataset = DatasetDTO( + name=row.from_dataset_name, + location=LocationDTO( + type=row.from_location_type, + name=row.from_location_name, + addresses=set(), + ), + ) + to_dataset = DatasetDTO( + name=row.to_dataset_name, + location=LocationDTO( + type=row.to_location_type, + name=row.to_location_name, + addresses=set(), + ), + ) + + fingerprint = compute_symlink_fingerprint( + [(from_dataset, from_role), (to_dataset, to_role)], + ) + + params.append( + { + "fingerprint": fingerprint, + "dataset_id": row.from_dataset_id, + "type": str(from_role), + }, + ) + params.append( + { + "fingerprint": fingerprint, + "dataset_id": row.to_dataset_id, + "type": str(to_role), + }, + ) + + if params: + bind.execute(insert_query, params) diff --git a/data_rentgen/db/models/__init__.py b/data_rentgen/db/models/__init__.py index 091963e5..622dacfb 100644 --- a/data_rentgen/db/models/__init__.py +++ b/data_rentgen/db/models/__init__.py @@ -12,6 +12,7 @@ DatasetColumnRelationType, ) from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType +from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup from data_rentgen.db.models.input import Input from data_rentgen.db.models.job import Job, JobTagValue from data_rentgen.db.models.job_dependency import JobDependency @@ -37,6 +38,7 @@ "DatasetColumnRelation", "DatasetColumnRelationType", "DatasetSymlink", + "DatasetSymlinkGroup", "DatasetSymlinkType", "DatasetTagValue", "Input", diff --git a/data_rentgen/db/models/dataset_symlink.py b/data_rentgen/db/models/dataset_symlink.py index 664268db..92a82f29 100644 --- a/data_rentgen/db/models/dataset_symlink.py +++ b/data_rentgen/db/models/dataset_symlink.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from sqlalchemy import BigInteger, ForeignKey, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import BigInteger, ForeignKey, MetaData, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy_utils import ChoiceType -from data_rentgen.db.models.base import Base -from data_rentgen.db.models.dataset import Dataset +_view_metadata = MetaData() + + +class _ViewBase(DeclarativeBase): + metadata = _view_metadata class DatasetSymlinkType(str, Enum): @@ -18,37 +21,24 @@ def __str__(self) -> str: return self.value -class DatasetSymlink(Base): +class DatasetSymlink(_ViewBase): + """Read-only ORM mapping for the dataset_symlink VIEW.""" + __tablename__ = "dataset_symlink" - __table_args__ = (UniqueConstraint("from_dataset_id", "to_dataset_id"),) - id: Mapped[int] = mapped_column(BigInteger, primary_key=True) from_dataset_id: Mapped[int] = mapped_column( BigInteger, ForeignKey("dataset.id", ondelete="CASCADE"), - index=True, + primary_key=True, nullable=False, ) - from_dataset: Mapped[Dataset] = relationship( - Dataset, - lazy="noload", - foreign_keys=[from_dataset_id], - ) - to_dataset_id: Mapped[int] = mapped_column( BigInteger, ForeignKey("dataset.id", ondelete="CASCADE"), - index=True, + primary_key=True, nullable=False, ) - to_dataset: Mapped[Dataset] = relationship( - Dataset, - lazy="noload", - foreign_keys=[to_dataset_id], - ) - type: Mapped[DatasetSymlinkType] = mapped_column( ChoiceType(DatasetSymlinkType, impl=String(32)), nullable=False, - doc="Type of dataset symlink, e.g. metastore table -> hdfs location", ) diff --git a/data_rentgen/db/models/dataset_symlink_group.py b/data_rentgen/db/models/dataset_symlink_group.py new file mode 100644 index 00000000..591f8b4e --- /dev/null +++ b/data_rentgen/db/models/dataset_symlink_group.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: 2024-present MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +from uuid import UUID + +from sqlalchemy import UUID as SQL_UUID +from sqlalchemy import BigInteger, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy_utils import ChoiceType + +from data_rentgen.db.models.base import Base +from data_rentgen.db.models.dataset import Dataset +from data_rentgen.db.models.dataset_symlink import DatasetSymlinkType + + +class DatasetSymlinkGroup(Base): + __tablename__ = "dataset_symlink_group" + + dataset_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey("dataset.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + dataset: Mapped[Dataset] = relationship( + Dataset, + lazy="noload", + foreign_keys=[dataset_id], + ) + fingerprint: Mapped[UUID] = mapped_column(SQL_UUID, primary_key=True) + type: Mapped[DatasetSymlinkType] = mapped_column( + ChoiceType(DatasetSymlinkType, impl=String(32)), + nullable=False, + ) diff --git a/data_rentgen/db/repositories/dataset_symlink.py b/data_rentgen/db/repositories/dataset_symlink.py index da9a3a69..19598367 100644 --- a/data_rentgen/db/repositories/dataset_symlink.py +++ b/data_rentgen/db/repositories/dataset_symlink.py @@ -3,23 +3,16 @@ from collections.abc import Collection -from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_ +from sqlalchemy import any_, bindparam, or_, select +from sqlalchemy.dialects.postgresql import insert from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType +from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup from data_rentgen.db.repositories.base import Repository -from data_rentgen.dto import DatasetSymlinkDTO +from data_rentgen.dto import DatasetSymlinkGroupDTO -fetch_bulk_query = select(DatasetSymlink).where( - tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_( - select( - func.unnest( - cast(bindparam("from_dataset_ids"), ARRAY(Integer())), - cast(bindparam("to_dataset_ids"), ARRAY(Integer())), - ) - .table_valued("from_dataset_ids", "to_dataset_ids") - .render_derived(), - ), - ), +insert_group_query = insert(DatasetSymlinkGroup).on_conflict_do_nothing( + index_elements=[DatasetSymlinkGroup.dataset_id, DatasetSymlinkGroup.fingerprint], ) get_list_query = select(DatasetSymlink).where( @@ -29,44 +22,24 @@ ), ) -get_one_query = ( - select(DatasetSymlink) - .where( - DatasetSymlink.from_dataset_id == bindparam("from_dataset_id"), - DatasetSymlink.to_dataset_id == bindparam("to_dataset_id"), - ) - .limit(1) -) - - -class DatasetSymlinkRepository(Repository[DatasetSymlink]): - async def fetch_bulk( - self, - dataset_symlinks_dto: list[DatasetSymlinkDTO], - ) -> list[tuple[DatasetSymlinkDTO, DatasetSymlink | None]]: - if not dataset_symlinks_dto: - return [] - scalars = await self._session.scalars( - fetch_bulk_query, - { - "from_dataset_ids": [item.from_dataset.id for item in dataset_symlinks_dto], - "to_dataset_ids": [item.to_dataset.id for item in dataset_symlinks_dto], - }, +class DatasetSymlinkRepository(Repository[DatasetSymlinkGroup]): + async def create_bulk(self, items: list[DatasetSymlinkGroupDTO]): + if not items: + return + + await self._session.execute( + insert_group_query, + [ + { + "fingerprint": item.fingerprint, + "dataset_id": dataset.id, + "type": DatasetSymlinkType(type_), + } + for item in items + for dataset, type_ in item.members + ], ) - existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()} - return [ - ( - dto, - existing.get((dto.from_dataset.id, dto.to_dataset.id)), # type: ignore[arg-type] - ) - for dto in dataset_symlinks_dto - ] - - async def create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink: - # if another worker already created the same row, just use it. if not - create with holding the lock. - await self._lock(dataset_symlink.from_dataset.id, dataset_symlink.to_dataset.id) - return await self._get(dataset_symlink) or await self._create(dataset_symlink) async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[DatasetSymlink]: if not dataset_ids: @@ -74,22 +47,3 @@ async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[Datase scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)}) return list(scalars.all()) - - async def _get(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink | None: - return await self._session.scalar( - get_one_query, - { - "from_dataset_id": dataset_symlink.from_dataset.id, - "to_dataset_id": dataset_symlink.to_dataset.id, - }, - ) - - async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink: - result = DatasetSymlink( - from_dataset_id=dataset_symlink.from_dataset.id, - to_dataset_id=dataset_symlink.to_dataset.id, - type=DatasetSymlinkType(dataset_symlink.type), - ) - self._session.add(result) - await self._session.flush([result]) - return result diff --git a/data_rentgen/db/scripts/seed/hive.py b/data_rentgen/db/scripts/seed/hive.py index 551a88eb..b4f02b6c 100644 --- a/data_rentgen/db/scripts/seed/hive.py +++ b/data_rentgen/db/scripts/seed/hive.py @@ -14,7 +14,7 @@ DatasetColumnRelationDTO, DatasetColumnRelationTypeDTO, DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, InputDTO, JobDTO, @@ -89,15 +89,11 @@ } DATASET_SYMLINKS = [ - DatasetSymlinkDTO( - from_dataset=DATASETS["hive_ref_user_info"], - to_dataset=DATASETS["hdfs_ref_user_info"], - type=DatasetSymlinkTypeDTO.WAREHOUSE, - ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hdfs_ref_user_info"], - to_dataset=DATASETS["hive_ref_user_info"], - type=DatasetSymlinkTypeDTO.METASTORE, + DatasetSymlinkGroupDTO( + members=[ + (DATASETS["hive_ref_user_info"], DatasetSymlinkTypeDTO.METASTORE), + (DATASETS["hdfs_ref_user_info"], DatasetSymlinkTypeDTO.WAREHOUSE), + ], ), ] @@ -150,7 +146,7 @@ def generate_hive_run( result.add_run(run) for symlink in DATASET_SYMLINKS: - result.add_dataset_symlink(symlink) + result.add_dataset_symlink_group(symlink) for generator in [load_ref_user_info]: operation, inputs, outputs, column_lineage = generator(faker, run) diff --git a/data_rentgen/db/scripts/seed/spark_local.py b/data_rentgen/db/scripts/seed/spark_local.py index 74d9c704..e0d3b6ac 100644 --- a/data_rentgen/db/scripts/seed/spark_local.py +++ b/data_rentgen/db/scripts/seed/spark_local.py @@ -14,7 +14,7 @@ DatasetColumnRelationDTO, DatasetColumnRelationTypeDTO, DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, InputDTO, JobDTO, @@ -101,15 +101,11 @@ } DATASET_SYMLINKS = [ - DatasetSymlinkDTO( - from_dataset=DATASETS["hive_raw_user_metrics"], - to_dataset=DATASETS["hdfs_raw_user_metrics"], - type=DatasetSymlinkTypeDTO.WAREHOUSE, - ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hdfs_raw_user_metrics"], - to_dataset=DATASETS["hive_raw_user_metrics"], - type=DatasetSymlinkTypeDTO.METASTORE, + DatasetSymlinkGroupDTO( + members=[ + (DATASETS["hive_raw_user_metrics"], DatasetSymlinkTypeDTO.METASTORE), + (DATASETS["hdfs_raw_user_metrics"], DatasetSymlinkTypeDTO.WAREHOUSE), + ], ), ] @@ -200,7 +196,7 @@ def generate_spark_run_local( result.add_run(run) for symlink in DATASET_SYMLINKS: - result.add_dataset_symlink(symlink) + result.add_dataset_symlink_group(symlink) for generator in [clickhouse_to_hive, postgres_to_hive]: operation, inputs, outputs, column_lineage = generator(faker, run) diff --git a/data_rentgen/db/scripts/seed/spark_yarn.py b/data_rentgen/db/scripts/seed/spark_yarn.py index a1bbc929..fdb1e29d 100644 --- a/data_rentgen/db/scripts/seed/spark_yarn.py +++ b/data_rentgen/db/scripts/seed/spark_yarn.py @@ -14,7 +14,7 @@ DatasetColumnRelationDTO, DatasetColumnRelationTypeDTO, DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, InputDTO, JobDependencyDTO, @@ -109,35 +109,23 @@ } DATASET_SYMLINKS = [ - DatasetSymlinkDTO( - from_dataset=DATASETS["hive_raw_user_metrics"], - to_dataset=DATASETS["hdfs_raw_user_metrics"], - type=DatasetSymlinkTypeDTO.WAREHOUSE, - ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hdfs_raw_user_metrics"], - to_dataset=DATASETS["hive_raw_user_metrics"], - type=DatasetSymlinkTypeDTO.METASTORE, - ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hive_ref_user_info"], - to_dataset=DATASETS["hdfs_ref_user_info"], - type=DatasetSymlinkTypeDTO.WAREHOUSE, - ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hdfs_ref_user_info"], - to_dataset=DATASETS["hive_ref_user_info"], - type=DatasetSymlinkTypeDTO.METASTORE, + DatasetSymlinkGroupDTO( + members=[ + (DATASETS["hive_raw_user_metrics"], DatasetSymlinkTypeDTO.METASTORE), + (DATASETS["hdfs_raw_user_metrics"], DatasetSymlinkTypeDTO.WAREHOUSE), + ], ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hive_mart_user_metrics_agg"], - to_dataset=DATASETS["hdfs_mart_user_metrics_agg"], - type=DatasetSymlinkTypeDTO.WAREHOUSE, + DatasetSymlinkGroupDTO( + members=[ + (DATASETS["hive_ref_user_info"], DatasetSymlinkTypeDTO.METASTORE), + (DATASETS["hdfs_ref_user_info"], DatasetSymlinkTypeDTO.WAREHOUSE), + ], ), - DatasetSymlinkDTO( - from_dataset=DATASETS["hdfs_mart_user_metrics_agg"], - to_dataset=DATASETS["hive_mart_user_metrics_agg"], - type=DatasetSymlinkTypeDTO.METASTORE, + DatasetSymlinkGroupDTO( + members=[ + (DATASETS["hive_mart_user_metrics_agg"], DatasetSymlinkTypeDTO.METASTORE), + (DATASETS["hdfs_mart_user_metrics_agg"], DatasetSymlinkTypeDTO.WAREHOUSE), + ], ), ] @@ -245,7 +233,7 @@ def generate_spark_run_yarn( result.add_run(run) for symlink in DATASET_SYMLINKS: - result.add_dataset_symlink(symlink) + result.add_dataset_symlink_group(symlink) for generator in [raw_to_mart]: operation, inputs, outputs, column_lineage = generator(faker, run) diff --git a/data_rentgen/dto/__init__.py b/data_rentgen/dto/__init__.py index 5b1b46bd..907c4a04 100644 --- a/data_rentgen/dto/__init__.py +++ b/data_rentgen/dto/__init__.py @@ -8,7 +8,11 @@ DatasetColumnRelationDTO, DatasetColumnRelationTypeDTO, ) -from data_rentgen.dto.dataset_symlink import DatasetSymlinkDTO, DatasetSymlinkTypeDTO +from data_rentgen.dto.dataset_symlink import ( + DatasetSymlinkGroupDTO, + DatasetSymlinkTypeDTO, + compute_symlink_fingerprint, +) from data_rentgen.dto.input import InputDTO from data_rentgen.dto.job import JobDTO from data_rentgen.dto.job_dependency import JobDependencyDTO @@ -33,7 +37,7 @@ "DatasetColumnRelationDTO", "DatasetColumnRelationTypeDTO", "DatasetDTO", - "DatasetSymlinkDTO", + "DatasetSymlinkGroupDTO", "DatasetSymlinkTypeDTO", "InputDTO", "JobDTO", @@ -54,4 +58,5 @@ "TagDTO", "TagValueDTO", "UserDTO", + "compute_symlink_fingerprint", ] diff --git a/data_rentgen/dto/base.py b/data_rentgen/dto/base.py index 23677e43..270b3212 100644 --- a/data_rentgen/dto/base.py +++ b/data_rentgen/dto/base.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: 2025-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Hashable from typing import Protocol, Self class DTO(Protocol): @property - def unique_key(self) -> tuple: + def unique_key(self) -> Hashable: """Expected to return the same value for the same DTO object""" ... diff --git a/data_rentgen/dto/dataset_symlink.py b/data_rentgen/dto/dataset_symlink.py index 7eafacc3..f3b99ab8 100644 --- a/data_rentgen/dto/dataset_symlink.py +++ b/data_rentgen/dto/dataset_symlink.py @@ -3,10 +3,14 @@ from __future__ import annotations -from dataclasses import dataclass, field +import json +from collections.abc import Iterable +from dataclasses import dataclass from enum import Enum +from uuid import UUID from data_rentgen.dto.dataset import DatasetDTO +from data_rentgen.utils.uuid import generate_static_uuid class DatasetSymlinkTypeDTO(str, Enum): @@ -17,19 +21,29 @@ def __str__(self) -> str: return self.value +DatasetSymlinkMemberDTO = tuple[DatasetDTO, DatasetSymlinkTypeDTO] + + +def compute_symlink_fingerprint( + members: Iterable[DatasetSymlinkMemberDTO], +) -> UUID: + normalized = sorted((dataset.unique_key, str(role)) for dataset, role in members) + return generate_static_uuid( + json.dumps(normalized, ensure_ascii=True), + ) + + @dataclass(slots=True) -class DatasetSymlinkDTO: - from_dataset: DatasetDTO - to_dataset: DatasetDTO - type: DatasetSymlinkTypeDTO - id: int | None = field(default=None, compare=False) +class DatasetSymlinkGroupDTO: + members: list[DatasetSymlinkMemberDTO] + + @property + def fingerprint(self) -> UUID: + return compute_symlink_fingerprint(self.members) @property - def unique_key(self) -> tuple: - return (self.from_dataset.unique_key, self.to_dataset.unique_key, self.type) + def unique_key(self) -> UUID: + return self.fingerprint - def merge(self, new: DatasetSymlinkDTO) -> DatasetSymlinkDTO: - self.from_dataset.merge(new.from_dataset) - self.to_dataset.merge(new.to_dataset) - self.id = new.id or self.id + def merge(self, new: DatasetSymlinkGroupDTO) -> DatasetSymlinkGroupDTO: return self diff --git a/docs/changelog/next_release/476.feature.rst b/docs/changelog/next_release/476.feature.rst new file mode 100644 index 00000000..f44d75f4 --- /dev/null +++ b/docs/changelog/next_release/476.feature.rst @@ -0,0 +1 @@ +Add dataset symlink group storage with compatibility view and group-based consumer writes. diff --git a/tests/test_consumer/test_extractors/fixtures/io_dto.py b/tests/test_consumer/test_extractors/fixtures/io_dto.py index 3714489d..7e03b661 100644 --- a/tests/test_consumer/test_extractors/fixtures/io_dto.py +++ b/tests/test_consumer/test_extractors/fixtures/io_dto.py @@ -2,7 +2,7 @@ from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, LocationDTO, SchemaDTO, @@ -88,50 +88,28 @@ def extracted_hive_dataset2( @pytest.fixture -def extracted_hdfs_dataset1_symlink( +def extracted_dataset1_symlink_group( extracted_hdfs_dataset1: DatasetDTO, extracted_hive_dataset1: DatasetDTO, -) -> DatasetSymlinkDTO: - return DatasetSymlinkDTO( - from_dataset=extracted_hdfs_dataset1, - to_dataset=extracted_hive_dataset1, - type=DatasetSymlinkTypeDTO.METASTORE, - ) - - -@pytest.fixture -def extracted_hive_dataset1_symlink( - extracted_hdfs_dataset1: DatasetDTO, - extracted_hive_dataset1: DatasetDTO, -) -> DatasetSymlinkDTO: - return DatasetSymlinkDTO( - from_dataset=extracted_hive_dataset1, - to_dataset=extracted_hdfs_dataset1, - type=DatasetSymlinkTypeDTO.WAREHOUSE, - ) - - -@pytest.fixture -def extracted_hdfs_dataset2_symlink( - extracted_hdfs_dataset2: DatasetDTO, - extracted_hive_dataset2: DatasetDTO, -) -> DatasetSymlinkDTO: - return DatasetSymlinkDTO( - from_dataset=extracted_hdfs_dataset2, - to_dataset=extracted_hive_dataset2, - type=DatasetSymlinkTypeDTO.METASTORE, +) -> DatasetSymlinkGroupDTO: + return DatasetSymlinkGroupDTO( + members=[ + (extracted_hdfs_dataset1, DatasetSymlinkTypeDTO.WAREHOUSE), + (extracted_hive_dataset1, DatasetSymlinkTypeDTO.METASTORE), + ], ) @pytest.fixture -def extracted_hive_dataset2_symlink( +def extracted_dataset2_symlink_group( extracted_hdfs_dataset2: DatasetDTO, extracted_hive_dataset2: DatasetDTO, -) -> DatasetSymlinkDTO: - return DatasetSymlinkDTO( - from_dataset=extracted_hive_dataset2, - to_dataset=extracted_hdfs_dataset2, - type=DatasetSymlinkTypeDTO.WAREHOUSE, +) -> DatasetSymlinkGroupDTO: + return DatasetSymlinkGroupDTO( + members=[ + (extracted_hdfs_dataset2, DatasetSymlinkTypeDTO.WAREHOUSE), + (extracted_hive_dataset2, DatasetSymlinkTypeDTO.METASTORE), + ], ) diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_airflow.py b/tests/test_consumer/test_extractors/test_extractors_batch_airflow.py index bca431c1..aedd30f7 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_airflow.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_airflow.py @@ -3,7 +3,7 @@ from data_rentgen.consumer.extractors import BatchExtractor from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, InputDTO, JobDTO, LocationDTO, @@ -64,7 +64,7 @@ def test_extractors_extract_batch_airflow_without_lineage( assert extracted.operations() == [extracted_airflow_task1_operation] assert not extracted.datasets() - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert not extracted.inputs() assert not extracted.outputs() @@ -104,8 +104,7 @@ def test_extractors_extract_batch_airflow_with_lineage( extracted_postgres_dataset: DatasetDTO, extracted_hdfs_dataset1: DatasetDTO, extracted_hive_dataset1: DatasetDTO, - extracted_hdfs_dataset1_symlink: DatasetSymlinkDTO, - extracted_hive_dataset1_symlink: DatasetSymlinkDTO, + extracted_dataset1_symlink_group: DatasetSymlinkGroupDTO, extracted_dataset_schema: SchemaDTO, extracted_user: UserDTO, extracted_airflow_postgres_input: InputDTO, @@ -149,10 +148,7 @@ def test_extractors_extract_batch_airflow_with_lineage( extracted_postgres_dataset, ] - assert extracted.dataset_symlinks() == [ - extracted_hdfs_dataset1_symlink, - extracted_hive_dataset1_symlink, - ] + assert extracted.dataset_symlink_groups() == [extracted_dataset1_symlink_group] # Both input & output schemas are the same assert extracted.schemas() == [extracted_dataset_schema] diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_dbt.py b/tests/test_consumer/test_extractors/test_extractors_batch_dbt.py index 55676b56..c7eb5293 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_dbt.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_dbt.py @@ -74,7 +74,7 @@ def test_extractors_extract_batch_dbt_spark_thrift( extracted_dbt_spark_target_dataset, ] - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert extracted.schemas() == [extracted_dbt_spark_source_schema] assert extracted.inputs() == [extracted_dbt_spark_input] diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_flink.py b/tests/test_consumer/test_extractors/test_extractors_batch_flink.py index 84794a6e..df47a21c 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_flink.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_flink.py @@ -86,7 +86,7 @@ def test_extractors_extract_batch_flink( extracted_postgres_dataset, ] - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() # Both input & output schemas are the same assert extracted.schemas() == [extracted_dataset_schema] diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_hive.py b/tests/test_consumer/test_extractors/test_extractors_batch_hive.py index a7301b06..a0630d68 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_hive.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_hive.py @@ -3,7 +3,7 @@ from data_rentgen.consumer.extractors import BatchExtractor from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, InputDTO, JobDTO, LocationDTO, @@ -47,10 +47,8 @@ def test_extractors_extract_batch_hive( extracted_hive_dataset2: DatasetDTO, extracted_hdfs_dataset1: DatasetDTO, extracted_hdfs_dataset2: DatasetDTO, - extracted_hdfs_dataset1_symlink: DatasetSymlinkDTO, - extracted_hive_dataset1_symlink: DatasetSymlinkDTO, - extracted_hdfs_dataset2_symlink: DatasetSymlinkDTO, - extracted_hive_dataset2_symlink: DatasetSymlinkDTO, + extracted_dataset1_symlink_group: DatasetSymlinkGroupDTO, + extracted_dataset2_symlink_group: DatasetSymlinkGroupDTO, extracted_dataset_schema: SchemaDTO, extracted_hive_job: JobDTO, extracted_hive_run: RunDTO, @@ -89,11 +87,9 @@ def test_extractors_extract_batch_hive( extracted_hive_dataset2, ] - assert extracted.dataset_symlinks() == [ - extracted_hdfs_dataset1_symlink, - extracted_hdfs_dataset2_symlink, - extracted_hive_dataset1_symlink, - extracted_hive_dataset2_symlink, + assert [group.fingerprint for group in extracted.dataset_symlink_groups()] == [ + extracted_dataset2_symlink_group.fingerprint, + extracted_dataset1_symlink_group.fingerprint, ] # Both input & output schemas are the same diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_spark.py b/tests/test_consumer/test_extractors/test_extractors_batch_spark.py index c386b194..50d39dc1 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_spark.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_spark.py @@ -6,7 +6,7 @@ from data_rentgen.consumer.extractors import BatchExtractor from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, InputDTO, JobDTO, LocationDTO, @@ -86,7 +86,7 @@ def test_extractors_extract_batch_spark_without_lineage( assert extracted.operations() == [extracted_spark_operation] assert not extracted.datasets() - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert not extracted.inputs() assert not extracted.outputs() @@ -126,7 +126,7 @@ def test_extractors_extract_batch_spark_openlineage_emitted_unknown_name( assert extracted.operations() == [extracted_spark_operation] assert not extracted.datasets() - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert not extracted.inputs() assert not extracted.outputs() @@ -193,7 +193,7 @@ def test_extractors_extract_batch_spark_openlineage_emitted_unknown_name_no_job_ assert extracted.operations() == [extracted_spark_operation_with_parent] assert not extracted.datasets() - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert not extracted.inputs() assert not extracted.outputs() @@ -229,8 +229,7 @@ def test_extractors_extract_batch_spark_with_lineage( extracted_postgres_dataset: DatasetDTO, extracted_hdfs_dataset1: DatasetDTO, extracted_hive_dataset1: DatasetDTO, - extracted_hdfs_dataset1_symlink: DatasetSymlinkDTO, - extracted_hive_dataset1_symlink: DatasetSymlinkDTO, + extracted_dataset1_symlink_group: DatasetSymlinkGroupDTO, extracted_dataset_schema: SchemaDTO, extracted_spark_app_job: JobDTO, extracted_user: UserDTO, @@ -283,10 +282,7 @@ def test_extractors_extract_batch_spark_with_lineage( extracted_postgres_dataset, ] - assert extracted.dataset_symlinks() == [ - extracted_hdfs_dataset1_symlink, - extracted_hive_dataset1_symlink, - ] + assert extracted.dataset_symlink_groups() == [extracted_dataset1_symlink_group] # Both input & output schemas are the same assert extracted.schemas() == [extracted_dataset_schema] diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_starrocks.py b/tests/test_consumer/test_extractors/test_extractors_batch_starrocks.py index 5c2c6798..0b0d8750 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_starrocks.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_starrocks.py @@ -78,7 +78,7 @@ def test_extractors_extract_batch_starrocks( extracted_iceberg_dataset2, ] - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert extracted.inputs() == [extracted_starrocks_input] diff --git a/tests/test_consumer/test_extractors/test_extractors_batch_unknown.py b/tests/test_consumer/test_extractors/test_extractors_batch_unknown.py index 53b8f156..f54f4dcd 100644 --- a/tests/test_consumer/test_extractors/test_extractors_batch_unknown.py +++ b/tests/test_consumer/test_extractors/test_extractors_batch_unknown.py @@ -3,7 +3,7 @@ from data_rentgen.consumer.extractors import BatchExtractor from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, InputDTO, JobDTO, LocationDTO, @@ -56,7 +56,7 @@ def test_extractors_extract_batch_unknown_without_lineage( assert not extracted.operations() assert not extracted.datasets() - assert not extracted.dataset_symlinks() + assert not extracted.dataset_symlink_groups() assert not extracted.schemas() assert not extracted.inputs() assert not extracted.outputs() @@ -90,8 +90,7 @@ def test_extractors_extract_batch_unknown_with_lineage( extracted_postgres_dataset: DatasetDTO, extracted_hdfs_dataset1: DatasetDTO, extracted_hive_dataset1: DatasetDTO, - extracted_hdfs_dataset1_symlink: DatasetSymlinkDTO, - extracted_hive_dataset1_symlink: DatasetSymlinkDTO, + extracted_dataset1_symlink_group: DatasetSymlinkGroupDTO, extracted_dataset_schema: SchemaDTO, extracted_unknown_job: JobDTO, extracted_unknown_run: RunDTO, @@ -141,10 +140,7 @@ def test_extractors_extract_batch_unknown_with_lineage( extracted_postgres_dataset, ] - assert extracted.dataset_symlinks() == [ - extracted_hdfs_dataset1_symlink, - extracted_hive_dataset1_symlink, - ] + assert extracted.dataset_symlink_groups() == [extracted_dataset1_symlink_group] # Both input & output schemas are the same assert extracted.schemas() == [extracted_dataset_schema] diff --git a/tests/test_consumer/test_extractors/test_extractors_dataset.py b/tests/test_consumer/test_extractors/test_extractors_dataset.py index 0db88807..7d91b5e0 100644 --- a/tests/test_consumer/test_extractors/test_extractors_dataset.py +++ b/tests/test_consumer/test_extractors/test_extractors_dataset.py @@ -2,7 +2,7 @@ from data_rentgen.consumer.extractors.impl import DbtExtractor, FlinkExtractor, SparkExtractor from data_rentgen.dto import ( DatasetDTO, - DatasetSymlinkDTO, + DatasetSymlinkGroupDTO, DatasetSymlinkTypeDTO, LocationDTO, TagDTO, @@ -95,8 +95,12 @@ def test_extractors_extract_dataset_hdfs_with_table_symlink(): dataset_dto, symlinks_dto = SparkExtractor().extract_dataset_and_symlinks(dataset) assert dataset_dto == hive_dataset assert symlinks_dto == [ - DatasetSymlinkDTO(from_dataset=hdfs_dataset, to_dataset=hive_dataset, type=DatasetSymlinkTypeDTO.METASTORE), - DatasetSymlinkDTO(from_dataset=hive_dataset, to_dataset=hdfs_dataset, type=DatasetSymlinkTypeDTO.WAREHOUSE), + DatasetSymlinkGroupDTO( + members=[ + (hdfs_dataset, DatasetSymlinkTypeDTO.WAREHOUSE), + (hive_dataset, DatasetSymlinkTypeDTO.METASTORE), + ], + ), ] @@ -192,8 +196,12 @@ def test_extractors_extract_dataset_hive_with_location_symlink(): dataset_dto, symlinks_dto = SparkExtractor().extract_dataset_and_symlinks(dataset) assert dataset_dto == hive_dataset assert symlinks_dto == [ - DatasetSymlinkDTO(from_dataset=hdfs_dataset, to_dataset=hive_dataset, type=DatasetSymlinkTypeDTO.METASTORE), - DatasetSymlinkDTO(from_dataset=hive_dataset, to_dataset=hdfs_dataset, type=DatasetSymlinkTypeDTO.WAREHOUSE), + DatasetSymlinkGroupDTO( + members=[ + (hive_dataset, DatasetSymlinkTypeDTO.METASTORE), + (hdfs_dataset, DatasetSymlinkTypeDTO.WAREHOUSE), + ], + ), ] diff --git a/tests/test_database/test_migration_dataset_symlink_group.py b/tests/test_database/test_migration_dataset_symlink_group.py new file mode 100644 index 00000000..f6c64d28 --- /dev/null +++ b/tests/test_database/test_migration_dataset_symlink_group.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: 2024-present MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.pool import NullPool + +from data_rentgen.db.models import Base +from data_rentgen.dto import ( + DatasetDTO, + DatasetSymlinkTypeDTO, + LocationDTO, + compute_symlink_fingerprint, +) +from tests.test_database.fixtures.alembic import do_run_migrations + +if TYPE_CHECKING: + from alembic.config import Config as AlembicConfig + +pytestmark = [pytest.mark.db] + +PREV_REVISION = "c3f8a2e1d749" +THIS_REVISION = "4a02d2d5c8b1" + + +def _dataset(location_type: str, location_name: str, name: str) -> DatasetDTO: + return DatasetDTO( + name=name, + location=LocationDTO(type=location_type, name=location_name, addresses=set()), + ) + + +def test_migration_backfill_dataset_symlink_group(empty_db_url: str, alembic_config: AlembicConfig): + do_run_migrations(alembic_config, Base.metadata, PREV_REVISION) + + engine = create_engine(empty_db_url, poolclass=NullPool) + try: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO location (id, type, name) VALUES + (1, 'hive', 'metastore'), + (2, 'hdfs', 'cluster') + """, + ), + ) + conn.execute( + text( + """ + INSERT INTO dataset (id, location_id, name) VALUES + (10, 1, 'schema.table1'), + (11, 1, 'schema.table2'), + (20, 2, '/warehouse/table') + """, + ), + ) + conn.execute( + text( + """ + INSERT INTO dataset_symlink (id, from_dataset_id, to_dataset_id, type) VALUES + (1, 20, 10, 'METASTORE'), + (2, 10, 20, 'WAREHOUSE'), + (3, 20, 11, 'METASTORE'), + (4, 11, 20, 'WAREHOUSE') + """, + ), + ) + + do_run_migrations(alembic_config, Base.metadata, THIS_REVISION) + + with engine.connect() as conn: + rows = conn.execute( + text("SELECT fingerprint::text, dataset_id, type FROM dataset_symlink_group"), + ).fetchall() + + hdfs = _dataset("hdfs", "cluster", "/warehouse/table") + hive1 = _dataset("hive", "metastore", "schema.table1") + hive2 = _dataset("hive", "metastore", "schema.table2") + + fingerprint_a = compute_symlink_fingerprint( + [(hdfs, DatasetSymlinkTypeDTO.WAREHOUSE), (hive1, DatasetSymlinkTypeDTO.METASTORE)], + ) + fingerprint_b = compute_symlink_fingerprint( + [(hdfs, DatasetSymlinkTypeDTO.WAREHOUSE), (hive2, DatasetSymlinkTypeDTO.METASTORE)], + ) + + actual = {(fingerprint, dataset_id, type_) for fingerprint, dataset_id, type_ in rows} + assert actual == { + (str(fingerprint_a), 20, "WAREHOUSE"), + (str(fingerprint_a), 10, "METASTORE"), + (str(fingerprint_b), 20, "WAREHOUSE"), + (str(fingerprint_b), 11, "METASTORE"), + } + + assert fingerprint_a != fingerprint_b + fingerprints_of_hdfs = {fingerprint for fingerprint, dataset_id, _ in rows if dataset_id == 20} + assert fingerprints_of_hdfs == {str(fingerprint_a), str(fingerprint_b)} + + # The old dataset_symlink table is replaced by a VIEW + with engine.connect() as conn: + view_rows = conn.execute( + text( + "SELECT from_dataset_id, to_dataset_id, type FROM dataset_symlink ORDER BY from_dataset_id, to_dataset_id" + ), + ).fetchall() + + assert set(view_rows) == { + (10, 20, "WAREHOUSE"), + (20, 10, "METASTORE"), + (11, 20, "WAREHOUSE"), + (20, 11, "METASTORE"), + } + + do_run_migrations(alembic_config, Base.metadata, PREV_REVISION) + + with engine.connect() as conn: + downgraded_rows = conn.execute( + text( + "SELECT from_dataset_id, to_dataset_id, type FROM dataset_symlink ORDER BY from_dataset_id, to_dataset_id" + ), + ).fetchall() + + assert set(downgraded_rows) == { + (10, 20, "WAREHOUSE"), + (20, 10, "METASTORE"), + (11, 20, "WAREHOUSE"), + (20, 11, "METASTORE"), + } + + finally: + engine.dispose() diff --git a/tests/test_server/fixtures/factories/dataset.py b/tests/test_server/fixtures/factories/dataset.py index 59ff4090..d688f6b4 100644 --- a/tests/test_server/fixtures/factories/dataset.py +++ b/tests/test_server/fixtures/factories/dataset.py @@ -4,9 +4,12 @@ from typing import TYPE_CHECKING import pytest_asyncio +from sqlalchemy.dialects.postgresql import insert from data_rentgen.db.models import Dataset, TagValue from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType +from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup +from data_rentgen.utils.uuid import generate_static_uuid from tests.test_server.fixtures.factories.base import random_string from tests.test_server.fixtures.factories.location import create_location from tests.test_server.utils.delete import clean_db @@ -58,15 +61,34 @@ async def make_symlink( to_dataset: Dataset, type: DatasetSymlinkType, ) -> DatasetSymlink: - symlink = DatasetSymlink( + await make_symlink_group(async_session, from_dataset, to_dataset, type) + return DatasetSymlink( from_dataset_id=from_dataset.id, to_dataset_id=to_dataset.id, type=type, ) - async_session.add(symlink) + + +async def make_symlink_group( + async_session: AsyncSession, + from_dataset: Dataset, + to_dataset: Dataset, + type: DatasetSymlinkType, +) -> None: + left, right = sorted((from_dataset.id, to_dataset.id)) + fingerprint = generate_static_uuid(f"symlink_group:{left}:{right}") + opposite = DatasetSymlinkType.WAREHOUSE if type == DatasetSymlinkType.METASTORE else DatasetSymlinkType.METASTORE + statement = insert(DatasetSymlinkGroup).on_conflict_do_nothing( + index_elements=[DatasetSymlinkGroup.dataset_id, DatasetSymlinkGroup.fingerprint], + ) + await async_session.execute( + statement, + [ + {"fingerprint": fingerprint, "dataset_id": from_dataset.id, "type": opposite}, + {"fingerprint": fingerprint, "dataset_id": to_dataset.id, "type": type}, + ], + ) await async_session.commit() - await async_session.refresh(symlink) - return symlink @pytest_asyncio.fixture(params=[{}]) diff --git a/tests/test_server/utils/delete.py b/tests/test_server/utils/delete.py index 3c6a2d0a..5f8e8811 100644 --- a/tests/test_server/utils/delete.py +++ b/tests/test_server/utils/delete.py @@ -4,7 +4,7 @@ from data_rentgen.db.models import ( Address, Dataset, - DatasetSymlink, + DatasetSymlinkGroup, DatasetTagValue, Input, Job, @@ -22,7 +22,7 @@ async def clean_db(async_session: AsyncSession) -> None: await async_session.execute(delete(DatasetTagValue)) await async_session.execute(delete(Location)) await async_session.execute(delete(Address)) - await async_session.execute(delete(DatasetSymlink)) + await async_session.execute(delete(DatasetSymlinkGroup)) await async_session.execute(delete(TagValue)) await async_session.execute(delete(Tag)) await async_session.execute(delete(Dataset))