Skip to content
Merged
510 changes: 497 additions & 13 deletions examples/merge_lineage.ipynb

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions src/clgraph/column_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,34 @@ def extract_merge_columns(ctx: ExtractionContext, unit: QueryUnit) -> List[Dict]
output_cols.append(col_info)
idx += 1

# 1b. Literal-bound match filter columns (edges for ON clause literal predicates)
match_filter_columns = config.get("match_filter_columns", [])
for col_name, literal_val in match_filter_columns:
col_info = {
"index": idx,
"name": col_name,
"is_star": False,
"type": "merge_match_filter",
"expression": f"{target_alias}.{col_name} = {literal_val}",
"ast_node": None,
"source_columns": [(target_alias, col_name)],
"merge_action": "match",
"merge_column_role": "condition",
}
output_cols.append(col_info)
idx += 1

# 2. WHEN MATCHED -> UPDATE columns
for action in matched_actions:
if action.get("action_type") == "update":
condition = action.get("condition")
# Note: target_alias is used as default_table, but WHEN conditions
# typically use qualified refs (t.name, s.name). extract_columns_from_expr
# uses the qualified table ref when present, so the default_table only
# applies to unqualified column names.
condition_columns = (
extract_columns_from_expr(condition, target_alias) if condition else []
)
for target_col, source_expr in action.get("column_mappings", {}).items():
col_info = {
"index": idx,
Expand All @@ -299,6 +323,7 @@ def extract_merge_columns(ctx: ExtractionContext, unit: QueryUnit) -> List[Dict]
"source_columns": extract_columns_from_expr(source_expr, source_alias),
"merge_action": "update",
"merge_condition": condition,
"condition_columns": condition_columns,
}
output_cols.append(col_info)
idx += 1
Expand All @@ -307,6 +332,11 @@ def extract_merge_columns(ctx: ExtractionContext, unit: QueryUnit) -> List[Dict]
for action in not_matched_actions:
if action.get("action_type") == "insert":
condition = action.get("condition")
# Note: source_alias is used as default_table (not target_alias) because
# NOT MATCHED conditions reference source rows (target row doesn't exist).
condition_columns = (
extract_columns_from_expr(condition, source_alias) if condition else []
)
for target_col, source_expr in action.get("column_mappings", {}).items():
col_info = {
"index": idx,
Expand All @@ -318,6 +348,7 @@ def extract_merge_columns(ctx: ExtractionContext, unit: QueryUnit) -> List[Dict]
"source_columns": extract_columns_from_expr(source_expr, source_alias),
"merge_action": "insert",
"merge_condition": condition,
"condition_columns": condition_columns,
}
output_cols.append(col_info)
idx += 1
Expand Down
1 change: 1 addition & 0 deletions src/clgraph/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def export(
edge_dict["is_merge_operation"] = True
edge_dict["merge_action"] = getattr(edge, "merge_action", None)
edge_dict["merge_condition"] = getattr(edge, "merge_condition", None)
edge_dict["merge_column_role"] = getattr(edge, "merge_column_role", None)

# Include QUALIFY clause metadata if present
if getattr(edge, "is_qualify_column", False):
Expand Down
7 changes: 6 additions & 1 deletion src/clgraph/lineage_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,12 @@ def _trace_column_dependencies(self, unit: QueryUnit, output_node: ColumnNode, c
return

# Branch 4: MERGE
if col_info.get("type") in ("merge_match", "merge_update", "merge_insert"):
if col_info.get("type") in (
"merge_match",
"merge_update",
"merge_insert",
"merge_match_filter",
):
trace_merge_columns(
self.lineage_graph,
unit,
Expand Down
3 changes: 3 additions & 0 deletions src/clgraph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ class ColumnEdge:
is_merge_operation: bool = False # True if this edge is from a MERGE statement
merge_action: Optional[str] = None # "match", "update", "insert", "delete"
merge_condition: Optional[str] = None # Condition for conditional WHEN clauses
merge_column_role: Optional[str] = (
None # None (match), "value" (SET RHS), or "condition" (WHEN guard / ON filter)
)

# ─── QUALIFY Clause Metadata ───
is_qualify_column: bool = False # True if this column is used in QUALIFY clause
Expand Down
1 change: 1 addition & 0 deletions src/clgraph/pipeline_lineage_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def _add_query_edges(
is_merge_operation=getattr(edge, "is_merge_operation", False),
merge_action=getattr(edge, "merge_action", None),
merge_condition=getattr(edge, "merge_condition", None),
merge_column_role=getattr(edge, "merge_column_role", None),
# Preserve QUALIFY clause metadata
is_qualify_column=getattr(edge, "is_qualify_column", False),
qualify_context=getattr(edge, "qualify_context", None),
Expand Down
6 changes: 6 additions & 0 deletions src/clgraph/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,12 +628,17 @@ def _parse_merge_statement(

# Extract match columns from ON condition
match_columns: List[Tuple[str, str]] = []
match_filter_columns: List[Tuple[str, str]] = []
if match_condition:
for eq in match_condition.find_all(exp.EQ):
left_col = eq.left
right_col = eq.right
if isinstance(left_col, exp.Column) and isinstance(right_col, exp.Column):
match_columns.append((left_col.name, right_col.name))
elif isinstance(left_col, exp.Column) and not isinstance(right_col, exp.Column):
match_filter_columns.append((left_col.name, right_col.sql()))
elif isinstance(right_col, exp.Column) and not isinstance(left_col, exp.Column):
match_filter_columns.append((right_col.name, left_col.sql()))

# Parse WHEN clauses from the 'whens' arg
whens = merge_node.args.get("whens")
Expand Down Expand Up @@ -698,6 +703,7 @@ def _parse_merge_statement(
"source_alias": source_alias,
"match_condition": match_condition_sql,
"match_columns": match_columns,
"match_filter_columns": match_filter_columns,
"matched_actions": matched_actions,
"not_matched_actions": not_matched_actions,
}
Expand Down
39 changes: 33 additions & 6 deletions src/clgraph/trace_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,32 @@ def trace_merge_columns(
merge_action = col_info.get("merge_action", col_info.get("type"))
merge_condition = col_info.get("merge_condition")
source_refs = col_info.get("source_columns", [])
condition_refs = col_info.get("condition_columns", [])

for source_ref in source_refs:
table_ref, col_name = source_ref[:2]

# Try to resolve as a source unit or base table
def _resolve_to_node(table_ref, col_name):
"""Resolve a (table, column) ref to a ColumnNode.

Lookup order:
1. Resolve ``table_ref`` to a source unit with ``resolve_source_unit`` and
look the column up in it with ``find_column_in_unit``.
2. If that yields nothing, resolve ``table_ref`` to a base table name with
``resolve_base_table_name`` and pass it to
``find_or_create_table_column_node``.
3. If no base table name is found, pass ``table_ref`` itself to
``find_or_create_table_column_node``.

Every lookup is guarded by ``table_ref`` being truthy, so a falsy
``table_ref`` (unqualified reference) returns ``None``.

Reads ``unit`` and ``graph`` from the enclosing function's scope.

NOTE(review): indentation is not preserved in this view, so the nesting of
steps 2 and 3 under ``if not source_node`` is inferred from the comments -
confirm against the real file.
"""
source_node = None
# Step 1: only attempt unit resolution when a table qualifier is present.
source_unit = resolve_source_unit(unit, table_ref) if table_ref else None
if source_unit:
source_node = find_column_in_unit(source_unit, col_name)
# Steps 2-3: no unit, or the unit lookup returned nothing falsy-wise.
if not source_node:
# Try as base table
base_table = resolve_base_table_name(unit, table_ref) if table_ref else None
if base_table:
source_node = find_or_create_table_column_node(graph, base_table, col_name)
elif table_ref:
# Fallback: use table_ref directly
source_node = find_or_create_table_column_node(graph, table_ref, col_name)
# Still None here when table_ref was falsy (no lookup was attempted).
return source_node

# Value-assignment edges (RHS of SET) or match/filter edges
for source_ref in source_refs:
table_ref, col_name = source_ref[:2]
source_node = _resolve_to_node(table_ref, col_name)
if source_node:
# Determine role: match edges get None, update/insert get "value",
# merge_match_filter edges keep their explicit "condition" role
role = col_info.get("merge_column_role")
if role is None and col_info["type"] in ("merge_update", "merge_insert"):
role = "value"
edge = ColumnEdge(
from_node=source_node,
to_node=output_node,
Expand All @@ -225,6 +232,26 @@ def trace_merge_columns(
is_merge_operation=True,
merge_action=merge_action,
merge_condition=merge_condition,
merge_column_role=role,
)
graph.add_edge(edge)

# Condition-gating edges (from WHEN AND clause)
for cond_ref in condition_refs:
table_ref, col_name = cond_ref[:2]
source_node = _resolve_to_node(table_ref, col_name)
if source_node:
edge = ColumnEdge(
from_node=source_node,
to_node=output_node,
edge_type=col_info["type"],
transformation=col_info["type"],
context=unit.unit_type.value,
expression=merge_condition,
is_merge_operation=True,
merge_action=merge_action,
merge_condition=merge_condition,
merge_column_role="condition",
)
graph.add_edge(edge)

Expand Down
Loading
Loading