From 35e5bed579decfadb0c6067952b3de066ff80cbe Mon Sep 17 00:00:00 2001 From: Ilyas Gasanov Date: Fri, 3 Jul 2026 17:45:14 +0300 Subject: [PATCH] [DOP-38392] Use recursive symlink queries --- ...02d2d5c8b1_create_dataset_symlink_group.py | 5 + .../db/models/dataset_symlink_group.py | 2 +- .../db/repositories/dataset_symlink.py | 35 +++++-- .../db/repositories/job_dependency.py | 60 +++++++---- data_rentgen/server/services/lineage.py | 55 +++++++---- data_rentgen/server/utils/lineage_response.py | 5 +- .../fixtures/factories/job_dependencies.py | 7 ++ .../test_server/fixtures/factories/lineage.py | 99 +++++++++++++++++++ .../test_lineage/test_dataset_lineage.py | 76 +++++++++++++- 9 files changed, 294 insertions(+), 50 deletions(-) diff --git a/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py b/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py index 4a5d8d19..283a8ed8 100644 --- a/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py +++ b/data_rentgen/db/migrations/versions/2026-06-29_4a02d2d5c8b1_create_dataset_symlink_group.py @@ -46,6 +46,11 @@ def upgrade() -> None: ), ) _backfill_symlink_groups() + op.create_index( + op.f("ix__dataset_symlink_group__fingerprint"), + "dataset_symlink_group", + ["fingerprint"], + ) op.drop_table("dataset_symlink") op.execute( sa.text( diff --git a/data_rentgen/db/models/dataset_symlink_group.py b/data_rentgen/db/models/dataset_symlink_group.py index 591f8b4e..fdbbb8d8 100644 --- a/data_rentgen/db/models/dataset_symlink_group.py +++ b/data_rentgen/db/models/dataset_symlink_group.py @@ -26,7 +26,7 @@ class DatasetSymlinkGroup(Base): lazy="noload", foreign_keys=[dataset_id], ) - fingerprint: Mapped[UUID] = mapped_column(SQL_UUID, primary_key=True) + fingerprint: Mapped[UUID] = mapped_column(SQL_UUID, primary_key=True, index=True) type: Mapped[DatasetSymlinkType] = mapped_column( ChoiceType(DatasetSymlinkType, impl=String(32)), nullable=False, diff --git a/data_rentgen/db/repositories/dataset_symlink.py b/data_rentgen/db/repositories/dataset_symlink.py index 19598367..31e64d42 100644 --- a/data_rentgen/db/repositories/dataset_symlink.py +++ b/data_rentgen/db/repositories/dataset_symlink.py @@ -3,10 +3,11 @@ from collections.abc import Collection -from sqlalchemy import any_, bindparam, or_, select +from sqlalchemy import ARRAY, BigInteger, bindparam, func, select from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import aliased -from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType +from data_rentgen.db.models.dataset_symlink import DatasetSymlinkType from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import DatasetSymlinkGroupDTO @@ -15,11 +16,24 @@ index_elements=[DatasetSymlinkGroup.dataset_id, DatasetSymlinkGroup.fingerprint], ) -get_list_query = select(DatasetSymlink).where( - or_( - DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")), - DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")), - ), +group_member = aliased(DatasetSymlinkGroup, name="group_member") +neighbour_group_member = aliased(DatasetSymlinkGroup, name="neighbour_group_member") + +closure_base_part = select( + func.unnest(bindparam("dataset_ids", type_=ARRAY(BigInteger()))).label("dataset_id"), +) +closure_cte = closure_base_part.cte("reachable_datasets", recursive=True) +closure_recursive_part = ( + select(neighbour_group_member.dataset_id.label("dataset_id")) + .select_from(closure_cte) + .join(group_member, group_member.dataset_id == closure_cte.c.dataset_id) + .join(neighbour_group_member, neighbour_group_member.fingerprint == group_member.fingerprint) +) +closure_cte = closure_cte.union(closure_recursive_part) + +get_symlink_groups_query = select(DatasetSymlinkGroup).join( + closure_cte, + DatasetSymlinkGroup.dataset_id == closure_cte.c.dataset_id, ) @@ -41,9 +55,12 @@ async def create_bulk(self, items: list[DatasetSymlinkGroupDTO]): ], ) - async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[DatasetSymlink]: + async def get_symlink_groups(self, dataset_ids: Collection[int]) -> list[DatasetSymlinkGroup]: if not dataset_ids: return [] - scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)}) + scalars = await self._session.scalars( + get_symlink_groups_query, + {"dataset_ids": list(dataset_ids)}, + ) return list(scalars.all()) diff --git a/data_rentgen/db/repositories/job_dependency.py b/data_rentgen/db/repositories/job_dependency.py index b30bd79a..979ed13b 100644 --- a/data_rentgen/db/repositories/job_dependency.py +++ b/data_rentgen/db/repositories/job_dependency.py @@ -19,14 +19,47 @@ select, tuple_, ) +from sqlalchemy.orm import aliased -from data_rentgen.db.models.dataset_symlink import DatasetSymlink +from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup from data_rentgen.db.models.input import Input from data_rentgen.db.models.job_dependency import JobDependency from data_rentgen.db.models.output import Output from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import JobDependencyDTO + +def _symlink_connected_cte(): + output = aliased(Output, name="connected_output") + base_part = ( + select( + output.dataset_id.label("original_dataset_id"), + output.dataset_id.label("dataset_id_via_symlink"), + ) + .where( + output.created_at >= bindparam("since"), + or_( + bindparam("until", type_=DateTime(timezone=True)).is_(None), + output.created_at <= bindparam("until"), + ), + ) + .distinct() + ) + cte = base_part.cte("symlink_connected", recursive=True) + reached_group = aliased(DatasetSymlinkGroup, name="connected_reached_group") + next_group = aliased(DatasetSymlinkGroup, name="connected_next_group") + recursive_part = ( + select( + cte.c.original_dataset_id.label("original_dataset_id"), + next_group.dataset_id.label("dataset_id_via_symlink"), + ) + .select_from(cte) + .join(reached_group, reached_group.dataset_id == cte.c.dataset_id_via_symlink) + .join(next_group, next_group.fingerprint == reached_group.fingerprint) + ) + return cte.union(recursive_part) + + fetch_bulk_query = select(JobDependency).where( tuple_(JobDependency.from_job_id, JobDependency.to_job_id).in_( select( @@ -163,25 +196,14 @@ def _get_core_hierarchy_query( Input, Output.dataset_id == Input.dataset_id, ).where(*where_clauses) - # IO connections Output.d_id == Symlink.to_d_id Symlink.from_d_id == Input.d_id - via_symlinks_from_output = ( - inferred_columns.join(DatasetSymlink, Output.dataset_id == DatasetSymlink.to_dataset_id) - .join( - Input, - DatasetSymlink.from_dataset_id == Input.dataset_id, - ) - .where(*where_clauses) - ) - # IO connections Input.d_id == Symlink.to_d_id Symlink.from_d_id == Output.d_id - via_symlinks_from_input = ( - inferred_columns.join(DatasetSymlink, Input.dataset_id == DatasetSymlink.to_dataset_id) - .join( - Output, - DatasetSymlink.from_dataset_id == Output.dataset_id, - ) - .where(*where_clauses) + # IO connections via symlinked datasets + connected = _symlink_connected_cte() + via_symlinks = ( + inferred_columns.join(connected, Output.dataset_id == connected.c.original_dataset_id) + .join(Input, connected.c.dataset_id_via_symlink == Input.dataset_id) + .where(*where_clauses, connected.c.original_dataset_id != connected.c.dataset_id_via_symlink) ) - query = query.union(direct_connection, via_symlinks_from_input, via_symlinks_from_output) + query = query.union(direct_connection, via_symlinks) return query diff --git a/data_rentgen/server/services/lineage.py b/data_rentgen/server/services/lineage.py index da351cc5..efff873d 100644 --- a/data_rentgen/server/services/lineage.py +++ b/data_rentgen/server/services/lineage.py @@ -6,14 +6,15 @@ from collections.abc import Collection from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Literal +from typing import Annotated, Literal, NamedTuple from uuid import UUID from fastapi import Depends from data_rentgen.db.models import ( Dataset, - DatasetSymlink, + DatasetSymlinkGroup, + DatasetSymlinkType, Job, Operation, Run, @@ -28,13 +29,37 @@ logger = logging.getLogger(__name__) +class SymlinkPair(NamedTuple): + from_dataset_id: int + to_dataset_id: int + type: DatasetSymlinkType + + +def _reconstruct_symlink_pairs(symlink_groups: list[DatasetSymlinkGroup]) -> list[SymlinkPair]: + members_by_fingerprint: dict[UUID, list[tuple[int, DatasetSymlinkType]]] = defaultdict(list) + for row in symlink_groups: + members_by_fingerprint[row.fingerprint].append((row.dataset_id, row.type)) + + pairs: dict[tuple[int, int], SymlinkPair] = {} + for members in members_by_fingerprint.values(): + for from_dataset_id, _ in members: + for to_dataset_id, to_type in members: + if from_dataset_id != to_dataset_id: + pairs[(from_dataset_id, to_dataset_id)] = SymlinkPair( + from_dataset_id=from_dataset_id, + to_dataset_id=to_dataset_id, + type=to_type, + ) + return list(pairs.values()) + + @dataclass class LineageServiceResult: jobs: dict[int, Job] = field(default_factory=dict) runs: dict[UUID, Run] = field(default_factory=dict) operations: dict[UUID, Operation] = field(default_factory=dict) datasets: dict[int, Dataset] = field(default_factory=dict) - dataset_symlinks: dict[tuple[int, int], DatasetSymlink] = field(default_factory=dict) + dataset_symlinks: dict[tuple[int, int], SymlinkPair] = field(default_factory=dict) inputs: dict[tuple[int, int, UUID | None, UUID | None], InputRow] = field(default_factory=dict) outputs: dict[tuple[int, int, UUID | None, UUID | None, int | None], OutputRow] = field(default_factory=dict) column_lineage: dict[tuple[int, int], list[ColumnLineageRow]] = field(default_factory=dict) @@ -682,7 +707,7 @@ async def get_lineage_by_datasets( # noqa: C901 datasets_by_id = {dataset.id: dataset for dataset in datasets} - # Threat dataset symlinks like they are specified in `start_node_ids` + # Treat dataset symlinks like they are specified in `start_node_ids` ids_to_skip = ids_to_skip or IdsToSkip() extra_dataset_ids, dataset_symlinks = await self._resolve_dataset_ids_via_symlink(datasets_by_id, ids_to_skip) extra_datasets = await self._uow.dataset.list_by_ids(extra_dataset_ids) @@ -775,22 +800,18 @@ async def _resolve_dataset_ids_via_symlink( self, dataset_ids: Collection[int], ids_to_skip: IdsToSkip | None = None, - ) -> tuple[set[int], list[DatasetSymlink]]: - # For now return all symlinks regardless of direction - dataset_symlinks = await self._uow.dataset_symlink.list_by_dataset_ids(dataset_ids) + ) -> tuple[set[int], list[SymlinkPair]]: + symlink_groups = await self._uow.dataset_symlink.get_symlink_groups(dataset_ids) ids_to_skip = ids_to_skip or IdsToSkip() - dataset_ids_from_symlinks = {dataset_symlink.from_dataset_id for dataset_symlink in dataset_symlinks} - dataset_ids_to_symlinks = {dataset_symlink.to_dataset_id for dataset_symlink in dataset_symlinks} - new_dataset_ids = ( - (dataset_ids_from_symlinks | dataset_ids_to_symlinks) - set(dataset_ids) - ids_to_skip.datasets - ) - return new_dataset_ids, dataset_symlinks + symlink_dataset_ids = {row.dataset_id for row in symlink_groups} + new_dataset_ids = symlink_dataset_ids - set(dataset_ids) - ids_to_skip.datasets + return new_dataset_ids, _reconstruct_symlink_pairs(symlink_groups) async def _dataset_lineage_with_operation_granularity( self, datasets_by_id: dict[int, Dataset], - dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink], + dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair], direction: LineageDirectionV1, since: datetime, until: datetime | None, @@ -931,7 +952,7 @@ async def _dataset_lineage_with_operation_granularity( async def _dataset_lineage_with_run_granularity( self, datasets_by_id: dict[int, Dataset], - dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink], + dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair], direction: LineageDirectionV1, since: datetime, until: datetime | None, @@ -1049,7 +1070,7 @@ async def _dataset_lineage_with_run_granularity( async def _dataset_lineage_with_job_granularity( self, datasets_by_id: dict[int, Dataset], - dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink], + dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair], direction: LineageDirectionV1, since: datetime, until: datetime | None, @@ -1142,7 +1163,7 @@ async def _dataset_lineage_with_job_granularity( async def _dataset_lineage_with_dataset_granularity( # noqa: C901, PLR0915, PLR0912 self, datasets_by_id: dict[int, Dataset], - dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink], + dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair], direction: LineageDirectionV1, since: datetime, until: datetime | None, diff --git a/data_rentgen/server/utils/lineage_response.py b/data_rentgen/server/utils/lineage_response.py index 44635db8..585b29b6 100644 --- a/data_rentgen/server/utils/lineage_response.py +++ b/data_rentgen/server/utils/lineage_response.py @@ -36,12 +36,11 @@ from uuid import UUID from data_rentgen.db.models.dataset import Dataset - from data_rentgen.db.models.dataset_symlink import DatasetSymlink from data_rentgen.db.models.operation import Operation from data_rentgen.db.models.run import Run from data_rentgen.db.repositories.column_lineage import ColumnLineageRow from data_rentgen.db.repositories.io_dataset_relation import IODatasetRelationRow - from data_rentgen.server.services.lineage import LineageServiceResult + from data_rentgen.server.services.lineage import LineageServiceResult, SymlinkPair def build_lineage_response(lineage: LineageServiceResult) -> LineageResponseV1: @@ -108,7 +107,7 @@ def _get_operation_parent_relations(operations: dict[UUID, Operation]) -> list[L return parents -def _get_symlink_relations(dataset_symlinks: dict[Any, DatasetSymlink]) -> list[LineageSymlinkRelationV1]: +def _get_symlink_relations(dataset_symlinks: dict[Any, SymlinkPair]) -> list[LineageSymlinkRelationV1]: symlinks = [] for key in sorted(dataset_symlinks): dataset_symlink = dataset_symlinks[key] diff --git a/tests/test_server/fixtures/factories/job_dependencies.py b/tests/test_server/fixtures/factories/job_dependencies.py index 8117ecdc..2801caf7 100644 --- a/tests/test_server/fixtures/factories/job_dependencies.py +++ b/tests/test_server/fixtures/factories/job_dependencies.py @@ -325,10 +325,17 @@ async def job_dependency_chain_with_lineage_and_symlinks( # Create datasets connected via symlinks. left_dataset_location = await create_location(async_session) left_output_dataset = await create_dataset(async_session, location_id=left_dataset_location.id) + left_intermediate_dataset = await create_dataset(async_session, location_id=left_dataset_location.id) left_input_dataset = await create_dataset(async_session, location_id=left_dataset_location.id) await make_symlink( async_session=async_session, from_dataset=left_output_dataset, + to_dataset=left_intermediate_dataset, + type=DatasetSymlinkType.METASTORE, + ) + await make_symlink( + async_session=async_session, + from_dataset=left_intermediate_dataset, to_dataset=left_input_dataset, type=DatasetSymlinkType.METASTORE, ) diff --git a/tests/test_server/fixtures/factories/lineage.py b/tests/test_server/fixtures/factories/lineage.py index 52e36cfc..166e70e3 100644 --- a/tests/test_server/fixtures/factories/lineage.py +++ b/tests/test_server/fixtures/factories/lineage.py @@ -1749,3 +1749,102 @@ async def lineage_with_parent_run_relations( async with async_session_maker() as async_session: await clean_db(async_session) + + +@pytest_asyncio.fixture() +async def lineage_with_transitive_symlinks( + async_session_maker: Callable[[], AbstractAsyncContextManager[AsyncSession]], + user: User, +) -> AsyncGenerator[LineageResult, None]: + created_at = datetime.now(tz=UTC) + + async with async_session_maker() as async_session: + builder = LineageBuilder(async_session) + + hdfs_location = await builder.create_location( + key="transitive_symlinks_hdfs_location", + location_kwargs={"type": "hdfs"}, + ) + hive_location = await builder.create_location( + key="transitive_symlinks_hive_location", + location_kwargs={"type": "hive"}, + ) + hdfs_shared = await builder.create_dataset( + key="transitive_symlinks_hdfs_shared", + location=hdfs_location, + dataset_kwargs={"name": "/warehouse/shared"}, + ) + hive_a = await builder.create_dataset( + key="transitive_symlinks_hive_a", + location=hive_location, + dataset_kwargs={"name": "schema.table_a"}, + ) + hive_b = await builder.create_dataset( + key="transitive_symlinks_hive_b", + location=hive_location, + dataset_kwargs={"name": "schema.table_b"}, + ) + + for name, hive_dataset in (("a", hive_a), ("b", hive_b)): + await builder.create_dataset_symlink( + key=f"transitive_symlinks_metastore_{name}", + from_dataset=hdfs_shared, + to_dataset=hive_dataset, + type=DatasetSymlinkType.METASTORE, + ) + await builder.create_dataset_symlink( + key=f"transitive_symlinks_warehouse_{name}", + from_dataset=hive_dataset, + to_dataset=hdfs_shared, + type=DatasetSymlinkType.WAREHOUSE, + ) + + schema = await builder.create_schema(key="transitive_symlinks_schema") + job_type = await builder.create_job_type(key="transitive_symlinks_job_type") + job = await builder.create_job( + key="transitive_symlinks_job", + location_key="transitive_symlinks_job_location", + job_type=job_type, + ) + run = await builder.create_run( + key="transitive_symlinks_run", + job=job, + run_kwargs={ + "job_id": job.id, + "started_by_user_id": user.id, + "created_at": created_at, + }, + ) + operation = await builder.create_operation( + key="transitive_symlinks_operation", + run=run, + operation_kwargs={ + "created_at": run.created_at + timedelta(seconds=0.2), + "run_id": run.id, + }, + ) + await builder.create_output( + key="transitive_symlinks_output", + operation=operation, + run=run, + job=job, + dataset=hive_a, + output_type=OutputType.APPEND, + schema=schema, + output_kwargs={ + "created_at": operation.created_at, + "operation_id": operation.id, + "run_id": operation.run_id, + "job_id": job.id, + "dataset_id": hive_a.id, + "type": OutputType.APPEND, + "schema_id": schema.id, + }, + ) + + lineage = builder.build() + + yield lineage + + async with async_session_maker() as async_session: + await clean_db(async_session) diff --git a/tests/test_server/test_lineage/test_dataset_lineage.py b/tests/test_server/test_lineage/test_dataset_lineage.py index c1f073de..684b18c0 100644 --- a/tests/test_server/test_lineage/test_dataset_lineage.py +++ b/tests/test_server/test_lineage/test_dataset_lineage.py @@ -1285,7 +1285,7 @@ async def test_get_dataset_lineage_with_symlink( datasets = [dataset for dataset in lineage.datasets if dataset.id in dataset_ids] assert datasets - # Threat all datasets from symlinks like they were passed as `start_node_id` + # Treat all datasets from symlinks like they were passed as `start_node_id` inputs = [input for input in lineage.inputs if input.dataset_id in dataset_ids] assert inputs @@ -1343,6 +1343,80 @@ async def test_get_dataset_lineage_with_symlink( } +async def test_get_dataset_lineage_with_transitive_symlinks( + test_client: AsyncClient, + async_session: AsyncSession, + lineage_with_transitive_symlinks: LineageResult, + mocked_user: MockedUser, +): + lineage = lineage_with_transitive_symlinks + datasets_by_name = {dataset.name: dataset for dataset in lineage.datasets} + hive_a = datasets_by_name["schema.table_a"] + hdfs_shared = datasets_by_name["/warehouse/shared"] + hive_b = datasets_by_name["schema.table_b"] + dataset_ids = {hive_a.id, hdfs_shared.id, hive_b.id} + + datasets = [dataset for dataset in lineage.datasets if dataset.id in dataset_ids] + assert datasets + + dataset_symlinks = [ + dataset_symlink + for dataset_symlink in lineage.dataset_symlinks + if dataset_symlink.from_dataset_id in dataset_ids or dataset_symlink.to_dataset_id in dataset_ids + ] + assert dataset_symlinks + + inputs = [input for input in lineage.inputs if input.dataset_id in dataset_ids] + outputs = [output for output in lineage.outputs if output.dataset_id in dataset_ids] + assert outputs + + run_ids = {output.run_id for output in outputs if output.run_id is not None} + runs = [run for run in lineage.runs if run.id in run_ids] + assert runs + + job_ids = {run.job_id for run in runs} + jobs = [job for job in lineage.jobs if job.id in job_ids] + assert jobs + + datasets = await enrich_datasets(datasets, async_session) + jobs = await enrich_jobs(jobs, async_session) + runs = await enrich_runs(runs, async_session) + since = min(run.created_at for run in lineage.runs) + + response = await test_client.get( + "v1/datasets/lineage", + headers={"Authorization": f"Bearer {mocked_user.access_token}"}, + params={ + "since": since.isoformat(), + "start_node_id": hive_a.id, + }, + ) + + assert response.status_code == HTTPStatus.OK, response.json() + assert response.json() == { + "relations": { + "parents": run_parents_to_json(runs), + "symlinks": symlinks_to_json(dataset_symlinks), + "inputs": [ + *inputs_to_json(merge_io_by_jobs(inputs), granularity="JOB"), + *inputs_to_json(merge_io_by_runs(inputs), granularity="RUN"), + ], + "outputs": [ + *outputs_to_json(merge_io_by_jobs(outputs), granularity="JOB"), + *outputs_to_json(merge_io_by_runs(outputs), granularity="RUN"), + ], + "direct_column_lineage": [], + "indirect_column_lineage": [], + }, + "nodes": { + "datasets": datasets_to_json(datasets, outputs, inputs), + "jobs": jobs_to_json(jobs), + "runs": runs_to_json(runs), + "operations": {}, + }, + } + + async def test_get_dataset_lineage_with_symlink_without_input_output( test_client: AsyncClient, async_session: AsyncSession,