From 12f2b62e67e31c78d7f97c9fe408aa6606dd16c0 Mon Sep 17 00:00:00 2001 From: Ming Jer Lee Date: Tue, 14 Apr 2026 08:58:56 -0400 Subject: [PATCH 1/2] feat: emit tagged predicate edges for JOIN ON clause columns JOIN ON predicate columns (e.g., temporal range joins, equi-join keys) previously produced zero column-lineage edges, making them invisible in impact analysis. This adds is_join_predicate=True edges from ON-clause columns to projected output columns from the joined table, with join_condition and join_side metadata. Supports equi-joins, range/BETWEEN, function-wrapped keys, multi-join chains, and composes with Gap 4 self-read nodes for self-referencing pipelines. Closes Gap 7 from the CDC/SCD pipeline gap analysis. --- TODO.md | 7 +- src/clgraph/lineage_builder.py | 115 ++++ src/clgraph/models.py | 19 + src/clgraph/pipeline_lineage_builder.py | 4 + src/clgraph/query_parser.py | 91 +++ tests/test_join_predicate_columns.py | 641 ++++++++++++++++++++ tests/test_unqualified_column_resolution.py | 23 +- 7 files changed, 890 insertions(+), 10 deletions(-) create mode 100644 tests/test_join_predicate_columns.py diff --git a/TODO.md b/TODO.md index 0a70b3e..43d0854 100644 --- a/TODO.md +++ b/TODO.md @@ -24,7 +24,6 @@ Discovered during CDC/SCD design review (see `docs/superpowers/specs/2026-04-13- - Implemented in PR #61: self-read node detection via AST node identity, cycle-safe dependency resolution, query-scoped `{query_id}:self_read:{table}.{col}` naming, column-granular cross-query wiring, edge role/order annotations. - Design: `docs/superpowers/specs/2026-04-13-gap4-self-referencing-target-design.md` -- [ ] **Gap 7. JOIN ON predicate columns not recorded in column lineage** - - Today: JOIN ON predicates produce **zero** column-lineage edges (no handling in `lineage_builder` for ON clause columns beyond the equi-join's identity resolution). - - Symptom: point-in-time joins like `o.order_ts BETWEEN d.start_time AND d.end_time` leave `start_time`/`end_time` invisible as influences on downstream columns. - - Needs its own design doc — new edge semantic for "predicate-conditional" columns. +- [x] **Gap 7. JOIN ON predicate columns not recorded in column lineage** + - Implemented: tagged `is_join_predicate=True` edges from ON-clause columns to right-side projected output columns. Supports equi-joins, range/BETWEEN, function-wrapped, multi-join chains, and Gap 4 self-read interaction. + - Design: `docs/superpowers/specs/2026-04-13-gap7-join-predicate-columns-design.md` diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index aaf16ff..1e190e2 100644 --- a/src/clgraph/lineage_builder.py +++ b/src/clgraph/lineage_builder.py @@ -173,6 +173,10 @@ def _process_unit(self, unit: QueryUnit): if unit.window_info: self._create_window_function_edges(unit, output_cols) + # 9. Create join predicate edges + if unit.join_predicates: + self._create_join_predicate_edges(unit, output_cols) + def _create_window_function_edges(self, unit: QueryUnit, output_cols: List[Dict]): """ Create edges for columns used in window functions. @@ -373,6 +377,117 @@ def _resolve_grouping_column(self, unit: QueryUnit, col_ref: str) -> Optional[Co # Reuse the QUALIFY column resolution logic return self._resolve_qualify_column(unit, col_ref) + def _create_join_predicate_edges(self, unit: QueryUnit, output_cols: List[Dict]): + """ + Create edges for columns used in JOIN ON clauses. + + For each JOIN predicate, identifies output columns sourced from the right-side + table and creates predicate edges from each ON-clause column to those output columns. + + Args: + unit: The query unit with join_predicates + output_cols: The output columns of this unit + """ + import logging + + logger = logging.getLogger("clgraph.lineage_builder") + + for info in unit.join_predicates: + right_table = info.right_table + if not right_table: + continue + + # Identify output columns sourced from the right-side table + right_output_nodes: List[ColumnNode] = [] + for col_info in output_cols: + if col_info.get("is_star"): + continue + node_key = get_node_key(unit, col_info) + if node_key not in self.lineage_graph.nodes: + continue + output_node = self.lineage_graph.nodes[node_key] + + # Check if this output column's source expression references the right table + ast_node = col_info.get("ast_node") + if ast_node is None: + continue + + is_right_sourced = False + for col_node in ast_node.find_all(exp.Column): + col_table = col_node.table if col_node.table else None + if col_table and col_table == right_table: + is_right_sourced = True + break + + if is_right_sourced: + right_output_nodes.append(output_node) + + if not right_output_nodes: + continue + + # For each column in the predicate, resolve to a source node and emit edges + for table_ref, col_name in info.columns: + source_node = self._resolve_join_predicate_column(unit, table_ref, col_name) + if not source_node: + logger.debug( + "Could not resolve join predicate column: %s.%s", table_ref, col_name + ) + continue + + # Determine join_side based on whether this column is from the right table + join_side = "right" if table_ref == right_table else "left" + + for output_node in right_output_nodes: + edge = ColumnEdge( + from_node=source_node, + to_node=output_node, + edge_type="join_predicate", + transformation="join_predicate", + context="JOIN ON", + expression=info.condition_sql, + is_join_predicate=True, + join_condition=info.condition_sql, + join_side=join_side, + ) + self.lineage_graph.add_edge(edge) + + def _resolve_join_predicate_column( + self, unit: QueryUnit, table_ref: Optional[str], col_name: str + ) -> Optional[ColumnNode]: + """ + Resolve a column reference from a JOIN ON clause to a ColumnNode. + + Args: + unit: The query unit + table_ref: Table alias/name or None if unqualified + col_name: Column name + + Returns: + ColumnNode or None if not found + """ + # Try to resolve as a source unit (CTE/subquery) + source_unit = self._resolve_source_unit(unit, table_ref) if table_ref else None + if source_unit: + return self._find_column_in_unit(source_unit, col_name) + + # Try as base table + base_table = self._resolve_base_table_name(unit, table_ref) if table_ref else None + if base_table: + return find_or_create_table_column_node(self.lineage_graph, base_table, col_name) + + # Try without table qualifier - infer from dependencies + if not table_ref and unit.depends_on_tables: + for table in unit.depends_on_tables: + node = find_or_create_table_column_node(self.lineage_graph, table, col_name) + if node: + return node + + # Fallback: use table_ref directly if provided + if table_ref: + return find_or_create_table_column_node(self.lineage_graph, table_ref, col_name) + + return None + def _create_qualify_edges(self, unit: QueryUnit, output_cols: List[Dict]): """ Create edges for columns used in QUALIFY clause. diff --git a/src/clgraph/models.py b/src/clgraph/models.py index fbb1a56..f096de8 100644 --- a/src/clgraph/models.py +++ b/src/clgraph/models.py @@ -177,6 +177,16 @@ def __repr__(self) -> str: return f"ValuesInfo({self.alias}({cols}), {self.row_count} rows)" +@dataclass +class JoinPredicateInfo: + """Information about a JOIN ON clause predicate for column lineage tracking.""" + + condition_sql: str # Raw SQL of the ON clause + columns: List[Tuple[Optional[str], str]] # (table_ref, col_name) pairs + join_type: str # "inner", "left", "right", "full", "cross" + right_table: Optional[str] # Name/alias of the joined (right-side) table + + @dataclass class QueryUnit: """ @@ -271,6 +281,10 @@ class QueryUnit: # Example: {'t': ValuesInfo(alias='t', column_names=['id', 'name'], row_count=2)} values_sources: Dict[str, "ValuesInfo"] = field(default_factory=dict) + # JOIN predicate metadata + # Stores info about JOIN ON clause columns for predicate lineage edges + join_predicates: List["JoinPredicateInfo"] = field(default_factory=list) + # Metadata depth: int = 0 # Nesting depth (0 = main query) order: int = 0 # Topological order for CTEs @@ -626,6 +640,11 @@ class ColumnEdge: tvf_info: Optional["TVFInfo"] = None # Full TVF specification is_tvf_output: bool = False # True if this edge is from a TVF output + # ─── JOIN Predicate Metadata ─── + is_join_predicate: bool = False # True if this edge is from a JOIN ON clause + join_condition: Optional[str] = None # Raw SQL of the ON clause + join_side: Optional[str] = None # "left" or "right" (which side of the join this column is on) + # ─── Self-Reference / Pipeline Ordering Metadata ─── statement_order: Optional[int] = None # Topological sort index of the query edge_role: Optional[str] = None # "prior_state_read", "cross_query_self_ref", or None diff --git a/src/clgraph/pipeline_lineage_builder.py b/src/clgraph/pipeline_lineage_builder.py index 34a92d4..55191af 100644 --- a/src/clgraph/pipeline_lineage_builder.py +++ b/src/clgraph/pipeline_lineage_builder.py @@ -466,6 +466,10 @@ def _add_query_edges( window_frame_end=getattr(edge, "window_frame_end", None), window_order_direction=getattr(edge, "window_order_direction", None), window_order_nulls=getattr(edge, "window_order_nulls", None), + # Preserve JOIN predicate metadata + is_join_predicate=getattr(edge, "is_join_predicate", False), + join_condition=getattr(edge, "join_condition", None), + join_side=getattr(edge, "join_side", None), # Preserve complex aggregate metadata aggregate_spec=getattr(edge, "aggregate_spec", None), ) diff --git a/src/clgraph/query_parser.py b/src/clgraph/query_parser.py index 6d733e5..9c23860 100644 --- a/src/clgraph/query_parser.py +++ b/src/clgraph/query_parser.py @@ -11,6 +11,7 @@ from sqlglot import exp from .models import ( + JoinPredicateInfo, QueryUnit, QueryUnitGraph, QueryUnitType, @@ -177,6 +178,22 @@ def _parse_select_unit( for join in joins: self._parse_from_sources(join, unit, depth) + # 3b. Extract JOIN ON predicate columns for lineage tracking + for join in joins: + on_clause = join.args.get("on") + if on_clause: + cols = self._extract_join_predicate_columns(on_clause) + join_type = self._get_join_type(join) + right_table = self._get_join_right_table(join, unit) + unit.join_predicates.append( + JoinPredicateInfo( + condition_sql=on_clause.sql(), + columns=cols, + join_type=join_type, + right_table=right_table, + ) + ) + # 4. Parse WHERE clause (may contain subqueries) where_clause = select_node.args.get("where") if where_clause: @@ -1460,6 +1477,80 @@ def register_preceding_table(name: str): if alias and alias not in parent_unit.lateral_sources: process_lateral_subquery(node, parent_unit, preceding_tables) + def _extract_join_predicate_columns( + self, on_clause: exp.Expression + ) -> List[Tuple[Optional[str], str]]: + """ + Extract column references from a JOIN ON clause expression. + + Walks the expression tree to find all exp.Column nodes and returns + (table_ref, col_name) pairs. Literals are ignored (they are not columns). + + Args: + on_clause: The ON clause expression from a JOIN + + Returns: + List of (table_ref_or_None, column_name) tuples + """ + columns: List[Tuple[Optional[str], str]] = [] + for node in on_clause.walk(): + if isinstance(node, exp.Column): + table_ref = node.table if node.table else None + col_name = node.name + columns.append((table_ref, col_name)) + return columns + + def _get_join_type(self, join: exp.Join) -> str: + """ + Extract the join type string from a sqlglot Join node. + + Args: + join: The sqlglot Join expression + + Returns: + Join type string like "inner", "left", "right", "full", "cross" + """ + side = join.side + kind = join.kind + + if side: + return side.lower() + if kind: + return kind.lower() + return "inner" + + def _get_join_right_table(self, join: exp.Join, unit: QueryUnit) -> Optional[str]: + """ + Extract the right-side table name or alias from a JOIN clause. + + The join's `this` contains the table being joined (the right side). + Uses the table's alias if present, otherwise the table name. + + Args: + join: The sqlglot Join expression + unit: The QueryUnit (for alias_mapping lookup) + + Returns: + The alias or name of the right-side table, or None + """ + table_node = join.this + + # Handle subquery case + if isinstance(table_node, exp.Subquery): + alias = table_node.alias + if alias: + return str(alias) + return None + + # Handle table case + if isinstance(table_node, exp.Table): + # Prefer alias over table name + if hasattr(table_node, "alias") and table_node.alias: + return str(table_node.alias) + return table_node.name + + return None + def _parse_where_subqueries( self, where_node: exp.Expression, parent_unit: QueryUnit, depth: int ): diff --git a/tests/test_join_predicate_columns.py b/tests/test_join_predicate_columns.py new file mode 100644 index 0000000..91544df --- /dev/null +++ b/tests/test_join_predicate_columns.py @@ -0,0 +1,641 @@ +""" +Test suite for Gap 7: JOIN ON Predicate Columns in Column Lineage. + +Tests cover: +- CDC/SCD2 point-in-time join (BETWEEN) +- Band join (non-CDC) +- Function-based join (UPPER) +- Multi-join chain +- Existing equi-join tests still pass +- Impact analysis opt-in/opt-out +- Dialect consistency +- Self-referencing query with JOIN predicates (Gap 4 interaction) +- Multi-statement SCD2 pipeline with JOIN predicates (Gap 4 + Gap 7) +- Unqualified predicate column handling + +Total: 10 test cases +""" + +import pytest + +from clgraph import Pipeline, RecursiveLineageBuilder, SQLColumnTracer + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _edges_dict(graph): + """Build a dict keyed by (from_full_name, to_full_name) -> edge.""" + return {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges} + + +def _predicate_edges(graph): + """Return only edges with is_join_predicate=True.""" + return [e for e in graph.edges if e.is_join_predicate] + + +def _predicate_edges_to(graph, target_full_name): + """Return predicate edges targeting a specific output column.""" + return [ + e for e in graph.edges if e.is_join_predicate and e.to_node.full_name == target_full_name + ] + + +def _predicate_sources_to(graph, target_full_name): + """Return set of from_node.full_name for predicate edges to a target.""" + return {e.from_node.full_name for e in _predicate_edges_to(graph, target_full_name)} + + +# ============================================================================ +# Test 1: CDC/SCD2 point-in-time join +# ============================================================================ + + +class TestCDCSCD2PointInTimeJoin: + """Test 1: CDC/SCD2 BETWEEN join produces predicate edges.""" + + SQL = """ + SELECT o.order_id, o.customer_id, o.order_ts, o.amount, + d.city AS customer_city_at_order + FROM raw_orders o + LEFT JOIN dim_customer d + ON o.customer_id = d.id + AND o.order_ts BETWEEN d.start_time AND d.end_time + """ + + def test_predicate_edges_from_dim_customer_start_time(self): + """dim_customer.start_time has predicate edge to output.customer_city_at_order.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + assert "dim_customer.start_time" in sources + + def test_predicate_edges_from_dim_customer_end_time(self): + """dim_customer.end_time has predicate edge to output.customer_city_at_order.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + assert "dim_customer.end_time" in sources + + def test_predicate_edges_from_raw_orders_order_ts(self): + """raw_orders.order_ts has predicate edge to output.customer_city_at_order.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + assert "raw_orders.order_ts" in sources + + def test_predicate_edges_from_raw_orders_customer_id(self): + """raw_orders.customer_id has predicate edge to output.customer_city_at_order.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + assert "raw_orders.customer_id" in sources + + def test_predicate_edges_from_dim_customer_id(self): + """dim_customer.id has predicate edge to output.customer_city_at_order.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + assert "dim_customer.id" in sources + + def test_value_edge_not_marked_as_predicate(self): + """dim_customer.city -> output.customer_city_at_order is NOT is_join_predicate.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _edges_dict(graph) + value_edge = edges.get(("dim_customer.city", "output.customer_city_at_order")) + assert value_edge is not None, "Value edge dim_customer.city -> output should exist" + assert not value_edge.is_join_predicate, "Value edge should not be marked as join predicate" + + def test_all_predicate_edges_have_join_condition(self): + """All predicate edges carry join_condition metadata.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + pred_edges = _predicate_edges_to(graph, "output.customer_city_at_order") + assert len(pred_edges) >= 5, f"Expected at least 5 predicate edges, got {len(pred_edges)}" + for edge in pred_edges: + assert edge.join_condition is not None, ( + f"Predicate edge from {edge.from_node.full_name} missing join_condition" + ) + assert edge.edge_type == "join_predicate", ( + f"Predicate edge should have edge_type='join_predicate', got '{edge.edge_type}'" + ) + + def test_predicate_edges_have_join_side(self): + """Predicate edges carry join_side ('left' or 'right').""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + pred_edges = _predicate_edges_to(graph, "output.customer_city_at_order") + sides = {e.from_node.full_name: e.join_side for e in pred_edges} + + # Right-side columns (dim_customer) + assert sides.get("dim_customer.start_time") == "right" + assert sides.get("dim_customer.end_time") == "right" + assert sides.get("dim_customer.id") == "right" + + # Left-side columns (raw_orders) + assert sides.get("raw_orders.customer_id") == "left" + assert sides.get("raw_orders.order_ts") == "left" + + +# ============================================================================ +# Test 2: Band join (non-CDC) +# ============================================================================ + + +class TestBandJoin: + """Test 2: Non-CDC band join with BETWEEN and equi-join produces predicate edges.""" + + SQL = """ + SELECT e.event_id, e.event_ts, s.sensor_id, s.reading + FROM events e + INNER JOIN sensor_data s + ON e.event_ts BETWEEN s.reading_ts - INTERVAL '5' MINUTE AND s.reading_ts + INTERVAL '5' MINUTE + AND e.location_id = s.location_id + """ + + def test_predicate_edges_to_right_side_columns(self): + """Predicate edges target right-side projected columns (sensor_data outputs).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + # Check predicate edges to output.sensor_id + sensor_id_sources = _predicate_sources_to(graph, "output.sensor_id") + reading_sources = _predicate_sources_to(graph, "output.reading") + + # sensor_data.reading_ts should have predicate edges to right-side outputs + assert ( + "sensor_data.reading_ts" in sensor_id_sources + or "sensor_data.reading_ts" in reading_sources + ), "sensor_data.reading_ts should have predicate edge to a right-side output" + + # events.event_ts should have predicate edges to right-side outputs + assert "events.event_ts" in sensor_id_sources or "events.event_ts" in reading_sources, ( + "events.event_ts should have predicate edge to a right-side output" + ) + + def test_location_id_predicate_edges(self): + """Both sides of equi-join in ON clause have predicate edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sensor_id_sources = _predicate_sources_to(graph, "output.sensor_id") + reading_sources = _predicate_sources_to(graph, "output.reading") + + all_pred_sources = sensor_id_sources | reading_sources + + assert "events.location_id" in all_pred_sources, ( + "events.location_id should have predicate edge to right-side output" + ) + assert "sensor_data.location_id" in all_pred_sources, ( + "sensor_data.location_id should have predicate edge to right-side output" + ) + + +# ============================================================================ +# Test 3: Function-based join +# ============================================================================ + + +class TestFunctionBasedJoin: + """Test 3: JOIN with UPPER() function wrapping produces predicate edges.""" + + SQL = """ + SELECT a.id, b.name + FROM table_a a + INNER JOIN table_b b ON UPPER(a.key) = UPPER(b.key) + """ + + def test_function_wrapped_columns_produce_predicate_edges(self): + """table_a.key and table_b.key have predicate edges to output.name.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + name_sources = _predicate_sources_to(graph, "output.name") + + assert "table_a.key" in name_sources, ( + "table_a.key should have predicate edge to output.name" + ) + assert "table_b.key" in name_sources, ( + "table_b.key should have predicate edge to output.name" + ) + + def test_no_predicate_edges_to_left_side_output(self): + """Predicate edges should not target left-side output (output.id).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + id_pred_sources = _predicate_sources_to(graph, "output.id") + assert len(id_pred_sources) == 0, ( + f"No predicate edges expected to output.id, got sources: {id_pred_sources}" + ) + + +# ============================================================================ +# Test 4: Multi-join chain +# ============================================================================ + + +class TestMultiJoinChain: + """Test 4: Multi-join chain has scoped predicate edges per JOIN.""" + + SQL = """ + SELECT a.id, b.val, c.label + FROM table_a a + INNER JOIN table_b b ON a.id = b.a_id + INNER JOIN table_c c ON b.id = c.b_id AND b.category = c.category + """ + + def test_first_join_predicate_edges_to_output_val(self): + """First join: a.id, b.a_id predicate edges to output.val.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + val_sources = _predicate_sources_to(graph, "output.val") + + assert "table_a.id" in val_sources, ( + "table_a.id should have predicate edge to output.val (first join)" + ) + assert "table_b.a_id" in val_sources, ( + "table_b.a_id should have predicate edge to output.val (first join)" + ) + + def test_second_join_predicate_edges_to_output_label(self): + """Second join: b.id, c.b_id, b.category, c.category predicate edges to output.label.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + label_sources = _predicate_sources_to(graph, "output.label") + + assert "table_b.id" in label_sources, ( + "table_b.id should have predicate edge to output.label (second join)" + ) + assert "table_c.b_id" in label_sources, ( + "table_c.b_id should have predicate edge to output.label (second join)" + ) + assert "table_b.category" in label_sources, ( + "table_b.category should have predicate edge to output.label (second join)" + ) + assert "table_c.category" in label_sources, ( + "table_c.category should have predicate edge to output.label (second join)" + ) + + def test_first_join_predicates_do_not_target_output_label(self): + """First join's predicates (a.id, b.a_id) should NOT have edges to output.label.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + label_sources = _predicate_sources_to(graph, "output.label") + + assert "table_a.id" not in label_sources, ( + "table_a.id (first join) should NOT have predicate edge to output.label" + ) + assert "table_b.a_id" not in label_sources, ( + "table_b.a_id (first join) should NOT have predicate edge to output.label" + ) + + +# ============================================================================ +# Test 5: Existing equi-join tests still pass +# ============================================================================ + + +class TestExistingEquiJoinTestsPass: + """Test 5: Verify existing join tests are not broken by predicate edge additions.""" + + def test_existing_join_types_tests_pass(self): + """ + Verify that existing test_join_types.py tests still pass. + + This is a meta-test: the actual verification is done by running + `uv run pytest tests/test_join_types.py -q` during CI. + Here we verify a representative case: simple INNER JOIN value edges + are unchanged after Gap 7 predicate edges are added. + """ + sql = """ + SELECT u.id, u.name, o.order_id, o.amount + FROM users u + INNER JOIN orders o ON u.id = o.user_id + """ + builder = RecursiveLineageBuilder(sql, dialect="bigquery") + graph = builder.build() + + edges = _edges_dict(graph) + + # Existing value edges are unchanged + assert ("users.id", "output.id") in edges + assert ("users.name", "output.name") in edges + assert ("orders.order_id", "output.order_id") in edges + assert ("orders.amount", "output.amount") in edges + + # Value edges are NOT marked as join predicates + assert not edges[("users.id", "output.id")].is_join_predicate + assert not edges[("users.name", "output.name")].is_join_predicate + + +# ============================================================================ +# Test 6: Impact analysis opt-in/opt-out +# ============================================================================ + + +class TestImpactAnalysisOptInOut: + """Test 6: Forward lineage includes predicate columns; flag is accessible.""" + + SQL = """ + SELECT o.order_id, o.customer_id, o.order_ts, o.amount, + d.city AS customer_city_at_order + FROM raw_orders o + LEFT JOIN dim_customer d + ON o.customer_id = d.id + AND o.order_ts BETWEEN d.start_time AND d.end_time + """ + + def test_forward_lineage_includes_predicate_column(self): + """Forward lineage from dim_customer.start_time includes customer_city_at_order.""" + tracer = SQLColumnTracer(self.SQL, dialect="bigquery") + forward = tracer.get_forward_lineage(["dim_customer.start_time"]) + + assert "customer_city_at_order" in forward["impacted_outputs"], ( + f"customer_city_at_order should be impacted by dim_customer.start_time; " + f"got {forward['impacted_outputs']}" + ) + + def test_is_join_predicate_flag_accessible(self): + """The is_join_predicate flag is accessible on edges from the graph.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + pred_edges = _predicate_edges(graph) + assert len(pred_edges) > 0, "Should have predicate edges" + + # Verify we can filter for value-only edges + value_edges = [e for e in graph.edges if not e.is_join_predicate] + assert len(value_edges) > 0, "Should have value edges" + assert len(value_edges) < len(graph.edges), ( + "Value edges should be fewer than total edges (some are predicate)" + ) + + +# ============================================================================ +# Test 7: Dialect consistency +# ============================================================================ + + +class TestDialectConsistency: + """Test 7: CDC BETWEEN join produces identical predicate edges across dialects.""" + + SQL = """ + SELECT o.order_id, o.customer_id, o.order_ts, o.amount, + d.city AS customer_city_at_order + FROM raw_orders o + LEFT JOIN dim_customer d + ON o.customer_id = d.id + AND o.order_ts BETWEEN d.start_time AND d.end_time + """ + + DIALECTS = ["bigquery", "postgres", "snowflake", "databricks"] + + @pytest.mark.parametrize("dialect", DIALECTS) + def test_predicate_edges_consistent_across_dialects(self, dialect): + """Predicate edges from CDC BETWEEN join are consistent across dialects.""" + builder = RecursiveLineageBuilder(self.SQL, dialect=dialect) + graph = builder.build() + + sources = _predicate_sources_to(graph, "output.customer_city_at_order") + + assert "dim_customer.start_time" in sources, ( + f"[{dialect}] dim_customer.start_time should have predicate edge" + ) + assert "dim_customer.end_time" in sources, ( + f"[{dialect}] dim_customer.end_time should have predicate edge" + ) + assert "raw_orders.order_ts" in sources, ( + f"[{dialect}] raw_orders.order_ts should have predicate edge" + ) + assert "raw_orders.customer_id" in sources, ( + f"[{dialect}] raw_orders.customer_id should have predicate edge" + ) + assert "dim_customer.id" in sources, ( + f"[{dialect}] dim_customer.id should have predicate edge" + ) + + # Value edge exists and is not a predicate + edges = _edges_dict(graph) + value_edge = edges.get(("dim_customer.city", "output.customer_city_at_order")) + assert value_edge is not None, f"[{dialect}] Value edge should exist" + assert not value_edge.is_join_predicate, f"[{dialect}] Value edge should not be predicate" + + +# ============================================================================ +# Test 8: Self-referencing query with JOIN predicates (Gap 4 interaction) +# ============================================================================ + + +class TestSelfRefWithJoinPredicates: + """Test 8: Single-query self-referencing INSERT with JOIN predicates.""" + + SQL = """\ +INSERT INTO dim_customer +SELECT s.id, s.name, s.city, s.email, + COALESCE(t.is_active, 'Y') AS is_active +FROM staging s +LEFT JOIN dim_customer t + ON s.id = t.id AND t.is_active = 'Y' +WHERE t.id IS NULL OR (t.name <> s.name OR t.city <> s.city) +""" + + def test_self_read_nodes_exist(self): + """Gap 4: self-read nodes should exist for dim_customer.""" + pipeline = Pipeline( + queries=[("q0", self.SQL)], + dialect="bigquery", + ) + + self_read_nodes = [ + col for col in pipeline.columns.values() if ":self_read:dim_customer." in col.full_name + ] + assert len(self_read_nodes) > 0, "Self-read nodes should exist for dim_customer" + + def test_predicate_edges_from_self_read_nodes(self): + """Gap 7: predicate edges should originate from self-read nodes, not physical nodes.""" + pipeline = Pipeline( + queries=[("q0", self.SQL)], + dialect="bigquery", + ) + + pred_edges = [e for e in pipeline.edges if e.is_join_predicate] + + # There should be some predicate edges + assert len(pred_edges) > 0, "Should have predicate edges from JOIN ON clause" + + # Predicate edges from the dim_customer side should reference self-read nodes + dim_pred_edges = [e for e in pred_edges if "dim_customer" in e.from_node.full_name] + for edge in dim_pred_edges: + assert ( + ":self_read:" in edge.from_node.full_name or edge.from_node.node_type == "self_read" + ), ( + f"Predicate edge from dim_customer should originate from self-read node, " + f"got {edge.from_node.full_name} (node_type={edge.from_node.node_type})" + ) + + def test_no_predicate_edge_from_where_only_columns(self): + """WHERE-only columns (t.name, t.city in WHERE but not ON) should NOT have predicate edges.""" + pipeline = Pipeline( + queries=[("q0", self.SQL)], + dialect="bigquery", + ) + + pred_edges = [e for e in pipeline.edges if e.is_join_predicate] + {e.from_node.column_name for e in pred_edges} + + # t.name and t.city appear in WHERE but not in ON clause + # They should NOT have join_predicate edges + # (Note: they may have other edge types, but not is_join_predicate) + on_clause_col_names = {"id", "is_active"} # columns in the ON clause from dim_customer + for edge in pred_edges: + if "dim_customer" in edge.from_node.full_name: + assert edge.from_node.column_name in on_clause_col_names, ( + f"Predicate edge from dim_customer column '{edge.from_node.column_name}' " + f"should only come from ON-clause columns {on_clause_col_names}" + ) + + +# ============================================================================ +# Test 9: Multi-statement SCD2 pipeline with JOIN predicates (Gap 4 + Gap 7) +# ============================================================================ + + +SCD2_MERGE_SQL = """\ +MERGE INTO dim_customer t +USING staging_customer_latest s ON t.id = s.id AND t.is_active = 'Y' +WHEN MATCHED AND (t.name <> s.name OR t.city <> s.city) THEN + UPDATE SET t.end_time = current_timestamp(), t.is_active = 'N' +""" + +SCD2_INSERT_SQL = """\ +INSERT INTO dim_customer +SELECT s.id, s.name, s.city, s.email, + current_timestamp() AS start_time, + TIMESTAMP '9999-12-31 00:00:00' AS end_time, + COALESCE(t.is_active, 'Y') AS is_active +FROM staging_customer_latest s +LEFT JOIN dim_customer t + ON s.id = t.id AND t.is_active = 'Y' +WHERE t.id IS NULL OR (t.name <> s.name OR t.city <> s.city) +""" + + +class TestMultiStatementSCD2Pipeline: + """Test 9: Two-step SCD2 pipeline (MERGE + INSERT) with Gap 4 + Gap 7 interaction.""" + + @pytest.fixture + def scd2_pipeline(self): + return Pipeline( + queries=[ + ("step1_merge", SCD2_MERGE_SQL), + ("step2_insert", SCD2_INSERT_SQL), + ], + dialect="bigquery", + ) + + def test_step2_on_clause_predicate_edges_exist(self, scd2_pipeline): + """Step 2's ON-clause predicate columns produce predicate edges.""" + pred_edges = [e for e in scd2_pipeline.edges if e.is_join_predicate] + assert len(pred_edges) > 0, "Step 2 JOIN should produce predicate edges" + + def test_step2_predicate_edges_from_self_read_nodes(self, scd2_pipeline): + """Step 2's predicate edges from dim_customer should use self-read nodes.""" + pred_edges = [e for e in scd2_pipeline.edges if e.is_join_predicate] + + dim_pred_edges = [e for e in pred_edges if "dim_customer" in e.from_node.full_name] + for edge in dim_pred_edges: + assert ( + ":self_read:" in edge.from_node.full_name or edge.from_node.node_type == "self_read" + ), ( + f"Step 2 predicate edge from dim_customer should be from self-read node, " + f"got {edge.from_node.full_name}" + ) + + def test_cross_query_edges_exist(self, scd2_pipeline): + """Gap 4: cross-query edges from Step 1 output to Step 2 self-read exist.""" + cross_query_edges = [ + e for e in scd2_pipeline.edges if e.edge_role == "cross_query_self_ref" + ] + assert len(cross_query_edges) > 0, "Cross-query self-ref edges should exist" + + def test_self_read_nodes_exist(self, scd2_pipeline): + """Gap 4: self-read nodes for dim_customer should exist.""" + self_read_nodes = [ + col + for col in scd2_pipeline.columns.values() + if ":self_read:dim_customer." in col.full_name + ] + assert len(self_read_nodes) > 0, "Self-read nodes for dim_customer should exist" + + +# ============================================================================ +# Test 10: Unqualified predicate column emits warning +# ============================================================================ + + +class TestUnqualifiedPredicateColumn: + """Test 10: Unqualified output column with qualified ON-clause columns.""" + + SQL = """ + SELECT a.id, name + FROM table_a a + INNER JOIN table_b b ON a.id = b.id + """ + + def test_unqualified_output_prevents_predicate_edge_targeting(self): + """Unqualified output 'name' cannot be traced to a specific table side. + + When the only right-side projected column is unqualified, the implementation + cannot determine it is sourced from the right table (table_b), so no + predicate edges are emitted. This is expected: ambiguous columns do not + produce predicate edges. The validation system emits a warning instead. + """ + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + pred_edges = _predicate_edges(graph) + + # No predicate edges are expected because the only right-side output + # column ('name') is unqualified and cannot be attributed to table_b. + assert len(pred_edges) == 0, ( + f"No predicate edges expected when right-side output is unqualified; " + f"got {len(pred_edges)} edges" + ) + + def test_unqualified_name_column_still_resolves(self): + """The unqualified 'name' column should still produce a value edge.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + # output.name should exist regardless of qualification ambiguity + assert "output.name" in graph.nodes, "output.name should exist in graph" + + # It should have at least one value (non-predicate) edge + value_edges_to_name = [ + e + for e in graph.edges + if e.to_node.full_name == "output.name" and not e.is_join_predicate + ] + # The unqualified 'name' may resolve to table_a.name or table_b.name + # depending on the implementation. Either way, a value edge should exist. + assert len(value_edges_to_name) >= 1, ( + "Unqualified 'name' should have at least one value edge" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_unqualified_column_resolution.py b/tests/test_unqualified_column_resolution.py index a0ff60c..7068f69 100644 --- a/tests/test_unqualified_column_resolution.py +++ b/tests/test_unqualified_column_resolution.py @@ -232,9 +232,12 @@ def test_three_layer_pipeline_date_trunc(self): pipeline = Pipeline.from_sql_list(queries, dialect="bigquery") - # Check that month has lineage to order_date + # Check that month has lineage to order_date (filter out join predicate edges) month_edges = [ - e for e in pipeline.edges if e.to_node.full_name == "reports.monthly_revenue.month" + e + for e in pipeline.edges + if e.to_node.full_name == "reports.monthly_revenue.month" + and not getattr(e, "is_join_predicate", False) ] assert len(month_edges) == 1 assert month_edges[0].from_node.column_name == "order_date" @@ -291,9 +294,12 @@ def test_ambiguous_column_resolved_correctly(self): assert len(value_a_edges) == 1 assert value_a_edges[0].from_node.table_name == "staging.table_a" - # value_b should come from table_b + # value_b should come from table_b (filter out join predicate edges) value_b_edges = [ - e for e in pipeline.edges if e.to_node.full_name == "reports.combined.value_b" + e + for e in pipeline.edges + if e.to_node.full_name == "reports.combined.value_b" + and not getattr(e, "is_join_predicate", False) ] assert len(value_b_edges) == 1 assert value_b_edges[0].from_node.table_name == "staging.table_b" @@ -329,8 +335,13 @@ def test_column_in_both_tables(self): assert len(amount_edges) == 1 assert amount_edges[0].from_node.table_name == "staging.orders" - # name should come from users - name_edges = [e for e in pipeline.edges if e.to_node.full_name == "reports.summary.name"] + # name should come from users (filter out join predicate edges) + name_edges = [ + e + for e in pipeline.edges + if e.to_node.full_name == "reports.summary.name" + and not getattr(e, "is_join_predicate", False) + ] assert len(name_edges) == 1 assert name_edges[0].from_node.table_name == "staging.users" From 166a57cf3c02973f23cf3772eaa042a5cf851242 Mon Sep 17 00:00:00 2001 From: Ming Jer Lee Date: Tue, 14 Apr 2026 10:10:16 -0400 Subject: [PATCH 2/2] docs: add example notebooks for self-referencing targets and JOIN predicates Two new example notebooks demonstrating Gap 4 and Gap 7 features: - self_referencing_lineage.ipynb: Single-statement self-ref, SCD2 MERGE+INSERT, impact analysis through self-read chains, edge annotations - join_predicate_lineage.ipynb: Equi-join predicates, point-in-time BETWEEN joins, multi-join chain scoping, impact analysis with predicate filtering --- examples/join_predicate_lineage.ipynb | 244 +++++++++++++++++++++ examples/self_referencing_lineage.ipynb | 268 ++++++++++++++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 examples/join_predicate_lineage.ipynb create mode 100644 examples/self_referencing_lineage.ipynb diff --git a/examples/join_predicate_lineage.ipynb b/examples/join_predicate_lineage.ipynb new file mode 100644 index 0000000..33770cd --- /dev/null +++ b/examples/join_predicate_lineage.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": [ + "# JOIN Predicate Column Lineage\n", + "\n", + "**Example: Tracking JOIN ON Predicate Columns in Column Lineage (Gap 7)**\n", + "\n", + "\n", + "This example demonstrates how clgraph tracks JOIN ON predicate columns as\n", + "lineage edges. Before Gap 7, only value-flow columns (columns in SELECT)\n", + "appeared in the lineage graph. Now, columns used in JOIN ON clauses are\n", + "tracked as predicate edges, making previously invisible dependencies\n", + "visible for impact analysis.\n", + "\n", + "Key features demonstrated:\n", + "1. Basic equi-join predicate edges with metadata\n", + "2. Point-in-time / range join (BETWEEN) with 5 predicate columns\n", + "3. Multi-join chain with per-join scoped predicate edges\n", + "4. Impact analysis using predicate edges with SQLColumnTracer" + ] + }, + { + "cell_type": "markdown", + "id": "b2c3d4e5", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3d4e5f6", + "metadata": {}, + "outputs": [], + "source": [ + "from clgraph import Pipeline, RecursiveLineageBuilder, SQLColumnTracer\n", + "\n", + "\n", + "def predicate_edges(graph):\n", + " \"\"\"Return only edges where is_join_predicate is True.\"\"\"\n", + " return [e for e in graph.edges if e.is_join_predicate]\n", + "\n", + "\n", + "def predicate_edges_to(graph, target):\n", + " \"\"\"Return predicate edges targeting a specific output column.\"\"\"\n", + " return [e for e in graph.edges if e.is_join_predicate and e.to_node.full_name == target]\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 1: Basic Equi-Join Predicate Edges\n", + "# ============================================================\n", + "print(\"=\" * 60)\n", + "print(\"Example 1: Basic Equi-Join Predicate Edges\")\n", + "print(\"=\" * 60)\n", + "\n", + "sql_1 = \"\"\"\n", + "SELECT o.order_id, o.amount, d.city AS customer_city\n", + "FROM raw_orders o\n", + "LEFT JOIN dim_customer d ON o.customer_id = d.id\n", + "\"\"\"\n", + "\n", + "builder_1 = RecursiveLineageBuilder(sql_1, dialect=\"bigquery\")\n", + "graph_1 = builder_1.build()\n", + "\n", + "print(f\"\\nQuery:{sql_1}\")\n", + "print(\"1a. Value edges (standard lineage):\")\n", + "for edge in graph_1.edges:\n", + " if not edge.is_join_predicate:\n", + " print(f\" {edge.from_node.full_name} -> {edge.to_node.full_name}\")\n", + "\n", + "print(\"\\n1b. Predicate edges (NEW \\u2014 from JOIN ON clause):\")\n", + "for edge in predicate_edges(graph_1):\n", + " print(f\" {edge.from_node.full_name} -> {edge.to_node.full_name}\")\n", + " print(f\" \\u2022 edge_type = {edge.edge_type}\")\n", + " print(f\" \\u2022 join_side = {edge.join_side}\")\n", + " print(f\" \\u2022 join_condition = {edge.join_condition}\")\n", + "\n", + "print(\"\\n1c. Compare: d.city \\u2192 output.customer_city is a VALUE edge:\")\n", + "for edge in graph_1.edges:\n", + " if edge.from_node.full_name == \"dim_customer.city\":\n", + " print(f\" is_join_predicate = {edge.is_join_predicate} \\u2713 (value flow, not predicate)\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 2: Point-in-Time / Range Join (BETWEEN)\n", + "# ============================================================\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 2: Point-in-Time Join (BETWEEN)\")\n", + "print(\"=\" * 60)\n", + "\n", + "sql_2 = \"\"\"\n", + "SELECT o.order_id, o.customer_id, o.order_ts, o.amount,\n", + " d.city AS customer_city_at_order\n", + "FROM raw_orders o\n", + "LEFT JOIN dim_customer d\n", + " ON o.customer_id = d.id\n", + " AND o.order_ts BETWEEN d.start_time AND d.end_time\n", + "\"\"\"\n", + "\n", + "builder_2 = RecursiveLineageBuilder(sql_2, dialect=\"bigquery\")\n", + "graph_2 = builder_2.build()\n", + "\n", + "print(f\"\\nQuery:{sql_2}\")\n", + "print(\"2a. All predicate edges \\u2192 customer_city_at_order:\")\n", + "pred_edges_2 = predicate_edges_to(graph_2, \"output.customer_city_at_order\")\n", + "for edge in pred_edges_2:\n", + " print(f\" {edge.from_node.full_name:30s} (join_side={edge.join_side})\")\n", + "\n", + "print(f\"\\n \\u2192 {len(pred_edges_2)} predicate columns detected\")\n", + "print(\" \\u2192 d.start_time and d.end_time were previously INVISIBLE in lineage\")\n", + "\n", + "print(\"\\n2b. Value edge is unchanged:\")\n", + "for edge in graph_2.edges:\n", + " if edge.from_node.full_name == \"dim_customer.city\" and not edge.is_join_predicate:\n", + " print(f\" {edge.from_node.full_name} -> {edge.to_node.full_name} (value edge \\u2713)\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 3: Multi-Join Chain with Per-Join Scoping\n", + "# ============================================================\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 3: Multi-Join Chain \\u2014 Per-Join Scoping\")\n", + "print(\"=\" * 60)\n", + "\n", + "sql_3 = \"\"\"\n", + "SELECT a.id, b.val, c.label\n", + "FROM table_a a\n", + "INNER JOIN table_b b ON a.id = b.a_id\n", + "INNER JOIN table_c c ON b.id = c.b_id AND b.category = c.category\n", + "\"\"\"\n", + "\n", + "builder_3 = RecursiveLineageBuilder(sql_3, dialect=\"bigquery\")\n", + "graph_3 = builder_3.build()\n", + "\n", + "print(f\"\\nQuery:{sql_3}\")\n", + "print(\"3a. First join predicates (a.id = b.a_id) \\u2192 output.val ONLY:\")\n", + "for edge in predicate_edges_to(graph_3, \"output.val\"):\n", + " print(f\" {edge.from_node.full_name} -> output.val\")\n", + "\n", + "print(\n", + " \"\\n3b. Second join predicates (b.id = c.b_id AND b.category = c.category) \\u2192 output.label ONLY:\"\n", + ")\n", + "for edge in predicate_edges_to(graph_3, \"output.label\"):\n", + " print(f\" {edge.from_node.full_name} -> output.label\")\n", + "\n", + "print(\"\\n3c. No cross-join leakage \\u2014 first join predicates do NOT target output.label:\")\n", + "label_pred_sources = {e.from_node.full_name for e in predicate_edges_to(graph_3, \"output.label\")}\n", + "assert \"table_a.id\" not in label_pred_sources, \"Cross-join leakage detected!\"\n", + "assert \"table_b.a_id\" not in label_pred_sources, \"Cross-join leakage detected!\"\n", + "print(\" \\u2713 table_a.id NOT in output.label predicates\")\n", + "print(\" \\u2713 table_b.a_id NOT in output.label predicates\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 4: Impact Analysis with Predicate Edges\n", + "# ============================================================\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 4: Impact Analysis with Predicate Edges\")\n", + "print(\"=\" * 60)\n", + "\n", + "print(\"\\nUsing the point-in-time join from Example 2.\")\n", + "\n", + "print(\"\\n4a. Forward trace from dim_customer.start_time (SQLColumnTracer):\")\n", + "tracer = SQLColumnTracer(sql_2, dialect=\"bigquery\")\n", + "forward = tracer.get_forward_lineage([\"dim_customer.start_time\"])\n", + "print(f\" Impacted outputs: {forward['impacted_outputs']}\")\n", + "print(\" \\u2192 customer_city_at_order is now reachable (was invisible before Gap 7)\")\n", + "\n", + "print(\"\\n4b. Filter predicate vs value edges using is_join_predicate:\")\n", + "value_edges_2 = [e for e in graph_2.edges if not e.is_join_predicate]\n", + "pred_edges_all_2 = [e for e in graph_2.edges if e.is_join_predicate]\n", + "print(f\" Value edges: {len(value_edges_2)}\")\n", + "print(f\" Predicate edges: {len(pred_edges_all_2)}\")\n", + "print(f\" Total edges: {len(graph_2.edges)}\")\n", + "\n", + "print(\"\\n4c. Existing value lineage is unchanged:\")\n", + "for edge in value_edges_2:\n", + " print(f\" {edge.from_node.full_name} -> {edge.to_node.full_name}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"JOIN Predicate Lineage Examples Complete!\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "id": "d4e5f6a7", + "metadata": {}, + "source": [ + "### Visualize Point-in-Time Join Lineage\n", + "\n", + "Display the column lineage graph for the point-in-time join (Example 2),\n", + "showing both value and predicate edges." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f6a7b8", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "from clgraph import visualize_pipeline_lineage\n", + "\n", + "# Create pipeline for visualization using the point-in-time join\n", + "sql_pit = \"\"\"\n", + "SELECT o.order_id, o.customer_id, o.order_ts, o.amount,\n", + " d.city AS customer_city_at_order\n", + "FROM raw_orders o\n", + "LEFT JOIN dim_customer d\n", + " ON o.customer_id = d.id\n", + " AND o.order_ts BETWEEN d.start_time AND d.end_time\n", + "\"\"\"\n", + "pit_pipeline = Pipeline([(\"pit_join\", sql_pit)], dialect=\"bigquery\")\n", + "\n", + "if shutil.which(\"dot\") is None:\n", + " print(\"\\u26a0\\ufe0f Graphviz not installed. Install with: brew install graphviz\")\n", + "else:\n", + " print(\"Point-in-Time Join \\u2014 Column Lineage (value + predicate edges):\")\n", + " display(visualize_pipeline_lineage(pit_pipeline.column_graph.to_simplified()))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/examples/self_referencing_lineage.ipynb b/examples/self_referencing_lineage.ipynb new file mode 100644 index 0000000..49b6129 --- /dev/null +++ b/examples/self_referencing_lineage.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": [ + "# Self-Referencing Target Lineage\n", + "\n", + "**Example: Tracking lineage when a query reads from a table it also writes to**\n", + "\n", + "\n", + "This example demonstrates how clgraph detects and models self-referencing\n", + "targets -- queries that INSERT/MERGE into a table while also reading from\n", + "that same table (common in SCD Type 2 and CDC patterns).\n", + "\n", + "Key features demonstrated:\n", + "1. Single-statement self-reference detection\n", + "2. Two-step SCD2 pipeline with cross-query self-read edges\n", + "3. Impact analysis through self-read chains\n", + "4. Edge annotations (`statement_order`, `edge_role`)" + ] + }, + { + "cell_type": "markdown", + "id": "e5f6a7b8", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9d0e1f2", + "metadata": {}, + "outputs": [], + "source": [ + "from clgraph import Pipeline\n", + "\n", + "# ============================================================\n", + "# Example 1: Single-Statement Self-Reference\n", + "# ============================================================\n", + "\n", + "print(\"=\" * 60)\n", + "print(\"Example 1: Single-Statement Self-Reference\")\n", + "print(\"=\" * 60)\n", + "\n", + "sql_insert_self_ref = \"\"\"\n", + "INSERT INTO dim_customer\n", + "SELECT s.id, s.name, s.city,\n", + " COALESCE(t.is_active, 'Y') AS is_active\n", + "FROM staging s\n", + "LEFT JOIN dim_customer t ON s.id = t.id\n", + "WHERE t.id IS NULL\n", + "\"\"\"\n", + "\n", + "pipeline_ex1 = Pipeline(\n", + " [(\"insert_new\", sql_insert_self_ref)],\n", + " dialect=\"bigquery\",\n", + ")\n", + "\n", + "# 1a. ParsedQuery exposes self_referenced_tables\n", + "query = list(pipeline_ex1.table_graph.queries.values())[0]\n", + "print(\"\\n1a ParsedQuery.self_referenced_tables\")\n", + "print(f\" destination : {query.destination_table}\")\n", + "print(f\" source_tables: {query.source_tables}\")\n", + "print(f\" self_ref : {query.self_referenced_tables}\")\n", + "\n", + "# 1b. Self-read nodes in the column graph\n", + "self_read_cols = pipeline_ex1.get_self_read_columns(\"dim_customer\")\n", + "print(f\"\\n1b Self-read nodes for dim_customer ({len(self_read_cols)} found):\")\n", + "for col in self_read_cols:\n", + " print(f\" -> {col.full_name} (node_type={col.node_type}, layer={col.layer})\")\n", + "\n", + "# 1c. Compare self-read vs physical output nodes\n", + "output_cols = [\n", + " c\n", + " for c in pipeline_ex1.columns.values()\n", + " if c.table_name == \"dim_customer\" and c.layer == \"output\"\n", + "]\n", + "print(f\"\\n1c Physical output nodes for dim_customer ({len(output_cols)} found):\")\n", + "for col in output_cols:\n", + " print(f\" -> {col.full_name} (node_type={col.node_type}, layer={col.layer})\")\n", + "\n", + "print(\"\\n Self-read nodes represent the PRIOR state read via LEFT JOIN.\")\n", + "print(\" Output nodes represent the NEW rows written by the INSERT.\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 2: Two-Step SCD2 (MERGE + INSERT)\n", + "# ============================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 2: Two-Step SCD2 Pipeline\")\n", + "print(\"=\" * 60)\n", + "\n", + "# Step 1: Close old rows\n", + "sql_scd2_merge = \"\"\"\n", + "MERGE INTO dim_customer t\n", + "USING staging s ON t.id = s.id AND t.is_active = 'Y'\n", + "WHEN MATCHED AND (t.name <> s.name OR t.city <> s.city) THEN\n", + " UPDATE SET t.end_time = current_timestamp(), t.is_active = 'N'\n", + "\"\"\"\n", + "\n", + "# Step 2: Open new version rows\n", + "sql_scd2_insert = \"\"\"\n", + "INSERT INTO dim_customer\n", + "SELECT s.id, s.name, s.city,\n", + " current_timestamp() AS start_time,\n", + " TIMESTAMP '9999-12-31 00:00:00' AS end_time,\n", + " COALESCE(t.is_active, 'Y') AS is_active\n", + "FROM staging s\n", + "LEFT JOIN dim_customer t ON s.id = t.id AND t.is_active = 'Y'\n", + "WHERE t.id IS NULL OR (t.name <> s.name OR t.city <> s.city)\n", + "\"\"\"\n", + "\n", + "pipeline_scd2 = Pipeline(\n", + " [\n", + " (\"step1_close_rows\", sql_scd2_merge),\n", + " (\"step2_new_versions\", sql_scd2_insert),\n", + " ],\n", + " dialect=\"bigquery\",\n", + ")\n", + "\n", + "# 2a. Topological sort order\n", + "sorted_ids = pipeline_scd2.table_graph.topological_sort()\n", + "print(\"\\n2a Topological sort order:\")\n", + "for i, qid in enumerate(sorted_ids):\n", + " q = pipeline_scd2.table_graph.queries[qid]\n", + " print(f\" {i}. {qid} ({q.operation.value} -> {q.destination_table})\")\n", + "\n", + "# 2b. Cross-query edges connecting Step 1 output to Step 2 self-read\n", + "cross_query_edges = [e for e in pipeline_scd2.edges if e.edge_role == \"cross_query_self_ref\"]\n", + "print(f\"\\n2b Cross-query self-ref edges ({len(cross_query_edges)} found):\")\n", + "for edge in cross_query_edges:\n", + " print(f\" {edge.from_node.full_name}\")\n", + " print(f\" -> {edge.to_node.full_name}\")\n", + " print(f\" edge_role={edge.edge_role}, statement_order={edge.statement_order}\")\n", + "\n", + "# 2c. Prior-state-read edges within Step 2\n", + "prior_state_edges = [e for e in pipeline_scd2.edges if e.edge_role == \"prior_state_read\"]\n", + "print(f\"\\n2c Prior-state-read edges ({len(prior_state_edges)} found):\")\n", + "for edge in prior_state_edges:\n", + " print(f\" {edge.from_node.full_name} -> {edge.to_node.full_name}\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 3: Impact Analysis Through Self-Read Chain\n", + "# ============================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 3: Impact Analysis Through Self-Read Chain\")\n", + "print(\"=\" * 60)\n", + "\n", + "# 3a. Forward trace: staging.city -> dim_customer.city through both steps\n", + "forward_hits = pipeline_scd2.trace_column_forward(\"staging\", \"city\")\n", + "print(\"\\n3a trace_column_forward('staging', 'city'):\")\n", + "for col in forward_hits:\n", + " print(f\" -> {col.table_name}.{col.column_name} (query={col.query_id})\")\n", + "\n", + "# 3b. Backward trace: dim_customer.is_active shows self-read chain\n", + "backward_hits = pipeline_scd2.trace_column_backward(\"dim_customer\", \"is_active\")\n", + "print(\"\\n3b trace_column_backward('dim_customer', 'is_active'):\")\n", + "for col in backward_hits:\n", + " print(\n", + " f\" -> {col.table_name}.{col.column_name} (query={col.query_id}, type={col.node_type})\"\n", + " )\n", + "\n", + "# 3c. get_self_read_columns API\n", + "sr_cols = pipeline_scd2.get_self_read_columns(\"dim_customer\")\n", + "print(f\"\\n3c get_self_read_columns('dim_customer') ({len(sr_cols)} nodes):\")\n", + "for col in sr_cols:\n", + " print(f\" -> {col.full_name}\")\n", + "\n", + "\n", + "# ============================================================\n", + "# Example 4: Edge Annotations (statement_order, edge_role)\n", + "# ============================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Example 4: Edge Annotations\")\n", + "print(\"=\" * 60)\n", + "\n", + "# 4a. Filter edges by edge_role\n", + "all_roles = {}\n", + "for edge in pipeline_scd2.edges:\n", + " role = edge.edge_role or \"(none)\"\n", + " all_roles.setdefault(role, []).append(edge)\n", + "\n", + "print(\"\\n4a Edge counts by edge_role:\")\n", + "for role, edges in sorted(all_roles.items()):\n", + " print(f\" {role}: {len(edges)} edges\")\n", + "\n", + "# 4b. Inspect prior_state_read edges\n", + "print(\"\\n4b Edges with edge_role='prior_state_read':\")\n", + "for edge in pipeline_scd2.edges:\n", + " if edge.edge_role == \"prior_state_read\":\n", + " print(f\" {edge.from_node.full_name}\")\n", + " print(f\" -> {edge.to_node.full_name}\")\n", + " print(f\" query_id={edge.query_id}, statement_order={edge.statement_order}\")\n", + "\n", + "# 4c. Inspect cross_query_self_ref edges\n", + "print(\"\\n4c Edges with edge_role='cross_query_self_ref':\")\n", + "for edge in pipeline_scd2.edges:\n", + " if edge.edge_role == \"cross_query_self_ref\":\n", + " print(f\" {edge.from_node.full_name}\")\n", + " print(f\" -> {edge.to_node.full_name}\")\n", + " print(f\" statement_order={edge.statement_order} (reflects topological position)\")\n", + "\n", + "# 4d. statement_order on all annotated edges\n", + "annotated_edges = [e for e in pipeline_scd2.edges if e.statement_order is not None]\n", + "print(f\"\\n4d All edges with statement_order set ({len(annotated_edges)} total):\")\n", + "for edge in annotated_edges:\n", + " print(\n", + " f\" order={edge.statement_order} role={edge.edge_role} \"\n", + " f\"{edge.from_node.full_name} -> {edge.to_node.full_name}\"\n", + " )\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Self-Referencing Lineage Examples Complete!\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "id": "f3a4b5c6", + "metadata": {}, + "source": [ + "### Visualize SCD2 Pipeline Lineage\n", + "\n", + "Display the simplified column lineage for the two-step SCD2 pipeline,\n", + "showing how self-read nodes connect the MERGE and INSERT steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7e8f9a0", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "from clgraph import visualize_pipeline_lineage\n", + "\n", + "if shutil.which(\"dot\") is None:\n", + " print(\"\\u26a0\\ufe0f Graphviz not installed. Install with: brew install graphviz\")\n", + "else:\n", + " print(\"SCD2 Pipeline - Simplified Lineage:\")\n", + " display(visualize_pipeline_lineage(pipeline_scd2.column_graph.to_simplified()))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}