Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def upgrade() -> None:
),
)
_backfill_symlink_groups()
op.create_index(
op.f("ix__dataset_symlink_group__fingerprint"),
"dataset_symlink_group",
["fingerprint"],
)
op.drop_table("dataset_symlink")
op.execute(
sa.text(
Expand Down
2 changes: 1 addition & 1 deletion data_rentgen/db/models/dataset_symlink_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DatasetSymlinkGroup(Base):
lazy="noload",
foreign_keys=[dataset_id],
)
fingerprint: Mapped[UUID] = mapped_column(SQL_UUID, primary_key=True)
fingerprint: Mapped[UUID] = mapped_column(SQL_UUID, primary_key=True, index=True)
type: Mapped[DatasetSymlinkType] = mapped_column(
ChoiceType(DatasetSymlinkType, impl=String(32)),
nullable=False,
Expand Down
35 changes: 26 additions & 9 deletions data_rentgen/db/repositories/dataset_symlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from collections.abc import Collection

from sqlalchemy import any_, bindparam, or_, select
from sqlalchemy import ARRAY, BigInteger, bindparam, func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import aliased

from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
from data_rentgen.db.models.dataset_symlink import DatasetSymlinkType
from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup
from data_rentgen.db.repositories.base import Repository
from data_rentgen.dto import DatasetSymlinkGroupDTO
Expand All @@ -15,11 +16,24 @@
index_elements=[DatasetSymlinkGroup.dataset_id, DatasetSymlinkGroup.fingerprint],
)

get_list_query = select(DatasetSymlink).where(
or_(
DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")),
DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")),
),
group_member = aliased(DatasetSymlinkGroup, name="group_member")
neighbour_group_member = aliased(DatasetSymlinkGroup, name="neighbour_group_member")

closure_base_part = select(
func.unnest(bindparam("dataset_ids", type_=ARRAY(BigInteger()))).label("dataset_id"),
)
closure_cte = closure_base_part.cte("reachable_datasets", recursive=True)
closure_recursive_part = (
select(neighbour_group_member.dataset_id.label("dataset_id"))
.select_from(closure_cte)
.join(group_member, group_member.dataset_id == closure_cte.c.dataset_id)
.join(neighbour_group_member, neighbour_group_member.fingerprint == group_member.fingerprint)
)
closure_cte = closure_cte.union(closure_recursive_part)

get_symlink_groups_query = select(DatasetSymlinkGroup).join(
closure_cte,
DatasetSymlinkGroup.dataset_id == closure_cte.c.dataset_id,
)


Expand All @@ -41,9 +55,12 @@ async def create_bulk(self, items: list[DatasetSymlinkGroupDTO]):
],
)

async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[DatasetSymlink]:
async def get_symlink_groups(self, dataset_ids: Collection[int]) -> list[DatasetSymlinkGroup]:
if not dataset_ids:
return []

scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
scalars = await self._session.scalars(
get_symlink_groups_query,
{"dataset_ids": list(dataset_ids)},
)
return list(scalars.all())
60 changes: 41 additions & 19 deletions data_rentgen/db/repositories/job_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,47 @@
select,
tuple_,
)
from sqlalchemy.orm import aliased

from data_rentgen.db.models.dataset_symlink import DatasetSymlink
from data_rentgen.db.models.dataset_symlink_group import DatasetSymlinkGroup
from data_rentgen.db.models.input import Input
from data_rentgen.db.models.job_dependency import JobDependency
from data_rentgen.db.models.output import Output
from data_rentgen.db.repositories.base import Repository
from data_rentgen.dto import JobDependencyDTO


def _symlink_connected_cte():
output = aliased(Output, name="connected_output")
base_part = (
select(
output.dataset_id.label("original_dataset_id"),
output.dataset_id.label("dataset_id_via_symlink"),
)
.where(
output.created_at >= bindparam("since"),
or_(
bindparam("until", type_=DateTime(timezone=True)).is_(None),
output.created_at <= bindparam("until"),
),
)
.distinct()
)
cte = base_part.cte("symlink_connected", recursive=True)
reached_group = aliased(DatasetSymlinkGroup, name="connected_reached_group")
next_group = aliased(DatasetSymlinkGroup, name="connected_next_group")
recursive_part = (
select(
cte.c.original_dataset_id.label("original_dataset_id"),
next_group.dataset_id.label("dataset_id_via_symlink"),
)
.select_from(cte)
.join(reached_group, reached_group.dataset_id == cte.c.dataset_id_via_symlink)
.join(next_group, next_group.fingerprint == reached_group.fingerprint)
)
return cte.union(recursive_part)


fetch_bulk_query = select(JobDependency).where(
tuple_(JobDependency.from_job_id, JobDependency.to_job_id).in_(
select(
Expand Down Expand Up @@ -163,25 +196,14 @@ def _get_core_hierarchy_query(
Input,
Output.dataset_id == Input.dataset_id,
).where(*where_clauses)
# IO connections Output.d_id == Symlink.to_d_id Symlink.from_d_id == Input.d_id
via_symlinks_from_output = (
inferred_columns.join(DatasetSymlink, Output.dataset_id == DatasetSymlink.to_dataset_id)
.join(
Input,
DatasetSymlink.from_dataset_id == Input.dataset_id,
)
.where(*where_clauses)
)
# IO connections Input.d_id == Symlink.to_d_id Symlink.from_d_id == Output.d_id
via_symlinks_from_input = (
inferred_columns.join(DatasetSymlink, Input.dataset_id == DatasetSymlink.to_dataset_id)
.join(
Output,
DatasetSymlink.from_dataset_id == Output.dataset_id,
)
.where(*where_clauses)
# IO connections via symlinked datasets
connected = _symlink_connected_cte()
via_symlinks = (
inferred_columns.join(connected, Output.dataset_id == connected.c.original_dataset_id)
.join(Input, connected.c.dataset_id_via_symlink == Input.dataset_id)
.where(*where_clauses, connected.c.original_dataset_id != connected.c.dataset_id_via_symlink)
)

query = query.union(direct_connection, via_symlinks_from_input, via_symlinks_from_output)
query = query.union(direct_connection, via_symlinks)

return query
55 changes: 38 additions & 17 deletions data_rentgen/server/services/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from collections.abc import Collection
from dataclasses import dataclass, field
from datetime import datetime
from typing import Annotated, Literal
from typing import Annotated, Literal, NamedTuple
from uuid import UUID

from fastapi import Depends

from data_rentgen.db.models import (
Dataset,
DatasetSymlink,
DatasetSymlinkGroup,
DatasetSymlinkType,
Job,
Operation,
Run,
Expand All @@ -28,13 +29,37 @@
logger = logging.getLogger(__name__)


class SymlinkPair(NamedTuple):
from_dataset_id: int
to_dataset_id: int
type: DatasetSymlinkType


def _reconstruct_symlink_pairs(symlink_groups: list[DatasetSymlinkGroup]) -> list[SymlinkPair]:
members_by_fingerprint: dict[UUID, list[tuple[int, DatasetSymlinkType]]] = defaultdict(list)
for row in symlink_groups:
members_by_fingerprint[row.fingerprint].append((row.dataset_id, row.type))

pairs: dict[tuple[int, int], SymlinkPair] = {}
for members in members_by_fingerprint.values():
for from_dataset_id, _ in members:
for to_dataset_id, to_type in members:
if from_dataset_id != to_dataset_id:
pairs[(from_dataset_id, to_dataset_id)] = SymlinkPair(
from_dataset_id=from_dataset_id,
to_dataset_id=to_dataset_id,
type=to_type,
)
return list(pairs.values())


@dataclass
class LineageServiceResult:
jobs: dict[int, Job] = field(default_factory=dict)
runs: dict[UUID, Run] = field(default_factory=dict)
operations: dict[UUID, Operation] = field(default_factory=dict)
datasets: dict[int, Dataset] = field(default_factory=dict)
dataset_symlinks: dict[tuple[int, int], DatasetSymlink] = field(default_factory=dict)
dataset_symlinks: dict[tuple[int, int], SymlinkPair] = field(default_factory=dict)
inputs: dict[tuple[int, int, UUID | None, UUID | None], InputRow] = field(default_factory=dict)
outputs: dict[tuple[int, int, UUID | None, UUID | None, int | None], OutputRow] = field(default_factory=dict)
column_lineage: dict[tuple[int, int], list[ColumnLineageRow]] = field(default_factory=dict)
Expand Down Expand Up @@ -682,7 +707,7 @@ async def get_lineage_by_datasets( # noqa: C901

datasets_by_id = {dataset.id: dataset for dataset in datasets}

# Threat dataset symlinks like they are specified in `start_node_ids`
# Treat dataset symlinks like they are specified in `start_node_ids`
ids_to_skip = ids_to_skip or IdsToSkip()
extra_dataset_ids, dataset_symlinks = await self._resolve_dataset_ids_via_symlink(datasets_by_id, ids_to_skip)
extra_datasets = await self._uow.dataset.list_by_ids(extra_dataset_ids)
Expand Down Expand Up @@ -775,22 +800,18 @@ async def _resolve_dataset_ids_via_symlink(
self,
dataset_ids: Collection[int],
ids_to_skip: IdsToSkip | None = None,
) -> tuple[set[int], list[DatasetSymlink]]:
# For now return all symlinks regardless of direction
dataset_symlinks = await self._uow.dataset_symlink.list_by_dataset_ids(dataset_ids)
) -> tuple[set[int], list[SymlinkPair]]:
symlink_groups = await self._uow.dataset_symlink.get_symlink_groups(dataset_ids)

ids_to_skip = ids_to_skip or IdsToSkip()
dataset_ids_from_symlinks = {dataset_symlink.from_dataset_id for dataset_symlink in dataset_symlinks}
dataset_ids_to_symlinks = {dataset_symlink.to_dataset_id for dataset_symlink in dataset_symlinks}
new_dataset_ids = (
(dataset_ids_from_symlinks | dataset_ids_to_symlinks) - set(dataset_ids) - ids_to_skip.datasets
)
return new_dataset_ids, dataset_symlinks
symlink_dataset_ids = {row.dataset_id for row in symlink_groups}
new_dataset_ids = symlink_dataset_ids - set(dataset_ids) - ids_to_skip.datasets
return new_dataset_ids, _reconstruct_symlink_pairs(symlink_groups)

async def _dataset_lineage_with_operation_granularity(
self,
datasets_by_id: dict[int, Dataset],
dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink],
dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair],
direction: LineageDirectionV1,
since: datetime,
until: datetime | None,
Expand Down Expand Up @@ -931,7 +952,7 @@ async def _dataset_lineage_with_operation_granularity(
async def _dataset_lineage_with_run_granularity(
self,
datasets_by_id: dict[int, Dataset],
dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink],
dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair],
direction: LineageDirectionV1,
since: datetime,
until: datetime | None,
Expand Down Expand Up @@ -1049,7 +1070,7 @@ async def _dataset_lineage_with_run_granularity(
async def _dataset_lineage_with_job_granularity(
self,
datasets_by_id: dict[int, Dataset],
dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink],
dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair],
direction: LineageDirectionV1,
since: datetime,
until: datetime | None,
Expand Down Expand Up @@ -1142,7 +1163,7 @@ async def _dataset_lineage_with_job_granularity(
async def _dataset_lineage_with_dataset_granularity( # noqa: C901, PLR0915, PLR0912
self,
datasets_by_id: dict[int, Dataset],
dataset_symlinks_by_id: dict[tuple[int, int], DatasetSymlink],
dataset_symlinks_by_id: dict[tuple[int, int], SymlinkPair],
direction: LineageDirectionV1,
since: datetime,
until: datetime | None,
Expand Down
5 changes: 2 additions & 3 deletions data_rentgen/server/utils/lineage_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@
from uuid import UUID

from data_rentgen.db.models.dataset import Dataset
from data_rentgen.db.models.dataset_symlink import DatasetSymlink
from data_rentgen.db.models.operation import Operation
from data_rentgen.db.models.run import Run
from data_rentgen.db.repositories.column_lineage import ColumnLineageRow
from data_rentgen.db.repositories.io_dataset_relation import IODatasetRelationRow
from data_rentgen.server.services.lineage import LineageServiceResult
from data_rentgen.server.services.lineage import LineageServiceResult, SymlinkPair


def build_lineage_response(lineage: LineageServiceResult) -> LineageResponseV1:
Expand Down Expand Up @@ -108,7 +107,7 @@ def _get_operation_parent_relations(operations: dict[UUID, Operation]) -> list[L
return parents


def _get_symlink_relations(dataset_symlinks: dict[Any, DatasetSymlink]) -> list[LineageSymlinkRelationV1]:
def _get_symlink_relations(dataset_symlinks: dict[Any, SymlinkPair]) -> list[LineageSymlinkRelationV1]:
symlinks = []
for key in sorted(dataset_symlinks):
dataset_symlink = dataset_symlinks[key]
Expand Down
7 changes: 7 additions & 0 deletions tests/test_server/fixtures/factories/job_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,17 @@ async def job_dependency_chain_with_lineage_and_symlinks(
# Create datasets connected via symlinks.
left_dataset_location = await create_location(async_session)
left_output_dataset = await create_dataset(async_session, location_id=left_dataset_location.id)
left_intermediate_dataset = await create_dataset(async_session, location_id=left_dataset_location.id)
left_input_dataset = await create_dataset(async_session, location_id=left_dataset_location.id)
await make_symlink(
async_session=async_session,
from_dataset=left_output_dataset,
to_dataset=left_intermediate_dataset,
type=DatasetSymlinkType.METASTORE,
)
await make_symlink(
async_session=async_session,
from_dataset=left_intermediate_dataset,
to_dataset=left_input_dataset,
type=DatasetSymlinkType.METASTORE,
)
Expand Down
Loading