Skip to content

Commit 7d69f1c

Browse files
authored
Chore!: Reintroduce tagging queries with correlation ID (#4895)
1 parent 7fb652e commit 7d69f1c

File tree

6 files changed

+69
-17
lines changed

6 files changed

+69
-17
lines changed

sqlmesh/core/context.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator:
453453
if not self._snapshot_evaluator:
454454
self._snapshot_evaluator = SnapshotEvaluator(
455455
{
456-
gateway: adapter.with_settings(log_level=logging.INFO)
456+
gateway: adapter.with_settings(execute_log_level=logging.INFO)
457457
for gateway, adapter in self.engine_adapters.items()
458458
},
459459
ddl_concurrent_tasks=self.concurrent_tasks,
@@ -520,7 +520,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
520520

521521
return model
522522

523-
def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
523+
def scheduler(
524+
self,
525+
environment: t.Optional[str] = None,
526+
snapshot_evaluator: t.Optional[SnapshotEvaluator] = None,
527+
) -> Scheduler:
524528
"""Returns the built-in scheduler.
525529
526530
Args:
@@ -542,9 +546,11 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
542546
if not snapshots:
543547
raise ConfigError("No models were found")
544548

545-
return self.create_scheduler(snapshots)
549+
return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator)
546550

547-
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
551+
def create_scheduler(
552+
self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator
553+
) -> Scheduler:
548554
"""Creates the built-in scheduler.
549555
550556
Args:
@@ -555,7 +561,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
555561
"""
556562
return Scheduler(
557563
snapshots,
558-
self.snapshot_evaluator,
564+
snapshot_evaluator,
559565
self.state_sync,
560566
default_catalog=self.default_catalog,
561567
max_workers=self.concurrent_tasks,
@@ -1931,7 +1937,7 @@ def _table_diff(
19311937
)
19321938

19331939
return TableDiff(
1934-
adapter=adapter.with_settings(logger.getEffectiveLevel()),
1940+
adapter=adapter.with_settings(execute_log_level=logger.getEffectiveLevel()),
19351941
source=source,
19361942
target=target,
19371943
on=on,

sqlmesh/core/engine_adapter/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,23 @@ def __init__(
147147
self._multithreaded = multithreaded
148148
self.correlation_id = correlation_id
149149

150-
def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
150+
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
151+
extra_kwargs = {
152+
"null_connection": True,
153+
"execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level),
154+
**self._extra_config,
155+
**kwargs,
156+
}
157+
151158
adapter = self.__class__(
152159
self._connection_pool,
153160
dialect=self.dialect,
154161
sql_gen_kwargs=self._sql_gen_kwargs,
155162
default_catalog=self._default_catalog,
156-
execute_log_level=log_level,
157163
register_comments=self._register_comments,
158-
null_connection=True,
159164
multithreaded=self._multithreaded,
160165
pretty_sql=self._pretty_sql,
161-
**self._extra_config,
162-
**kwargs,
166+
**extra_kwargs,
163167
)
164168

165169
return adapter

sqlmesh/core/plan/evaluator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from sqlmesh.utils import to_snake_case
4040
from sqlmesh.core.state_sync import StateSync
41+
from sqlmesh.utils import CorrelationId
4142
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4243
from sqlmesh.utils.errors import PlanError, SQLMeshError
4344
from sqlmesh.utils.dag import DAG
@@ -71,7 +72,7 @@ def __init__(
7172
self,
7273
state_sync: StateSync,
7374
snapshot_evaluator: SnapshotEvaluator,
74-
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
75+
create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler],
7576
default_catalog: t.Optional[str],
7677
console: t.Optional[Console] = None,
7778
):
@@ -88,6 +89,9 @@ def evaluate(
8889
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
8990
) -> None:
9091
self._circuit_breaker = circuit_breaker
92+
self.snapshot_evaluator = self.snapshot_evaluator.set_correlation_id(
93+
CorrelationId.from_plan_id(plan.plan_id)
94+
)
9195

9296
self.console.start_plan_evaluation(plan)
9397
analytics.collector.on_plan_apply_start(
@@ -106,6 +110,7 @@ def evaluate(
106110
else:
107111
analytics.collector.on_plan_apply_end(plan_id=plan.plan_id)
108112
finally:
113+
self.snapshot_evaluator.recycle()
109114
self.console.stop_plan_evaluation()
110115

111116
def _evaluate_stages(
@@ -228,7 +233,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
228233
self.console.log_success("SKIP: No model batches to execute")
229234
return
230235

231-
scheduler = self.create_scheduler(stage.all_snapshots.values())
236+
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
232237
errors, _ = scheduler.run_merged_intervals(
233238
merged_intervals=stage.snapshot_to_intervals,
234239
deployability_index=stage.deployability_index,
@@ -249,7 +254,7 @@ def visit_audit_only_run_stage(
249254
return
250255

251256
# If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them
252-
scheduler = self.create_scheduler(audit_snapshots)
257+
scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator)
253258
completion_status = scheduler.audit(
254259
plan.environment,
255260
plan.start,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
SnapshotTableCleanupTask,
6262
)
6363
from sqlmesh.core.snapshot.definition import parent_snapshots_by_name
64-
from sqlmesh.utils import random_id
64+
from sqlmesh.utils import random_id, CorrelationId
6565
from sqlmesh.utils.concurrency import (
6666
concurrent_apply_to_snapshots,
6767
concurrent_apply_to_values,
@@ -127,6 +127,7 @@ def __init__(
127127
if not selected_gateway
128128
else self.adapters[selected_gateway]
129129
)
130+
self.selected_gateway = selected_gateway
130131
self.ddl_concurrent_tasks = ddl_concurrent_tasks
131132

132133
def evaluate(
@@ -1186,6 +1187,16 @@ def _execute_create(
11861187
)
11871188
adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs))
11881189

1190+
def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator:
1191+
return SnapshotEvaluator(
1192+
{
1193+
gateway: adapter.with_settings(correlation_id=correlation_id)
1194+
for gateway, adapter in self.adapters.items()
1195+
},
1196+
self.ddl_concurrent_tasks,
1197+
self.selected_gateway,
1198+
)
1199+
11891200

11901201
def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
11911202
klass: t.Type

tests/core/test_integration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
7272
from sqlmesh.utils.pydantic import validate_string
7373
from tests.conftest import DuckDBMetadata, SushiDataValidator
74+
from sqlmesh.utils import CorrelationId
7475
from tests.utils.test_helpers import use_terminal_console
7576
from tests.utils.test_filesystem import create_temp_file
7677

@@ -6815,3 +6816,28 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
68156816
# valid_from should be the epoch, valid_to should be NaT
68166817
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
68176818
assert pd.isna(row["valid_to"])
6819+
6820+
6821+
def test_plan_evaluator_correlation_id(tmp_path: Path):
6822+
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
6823+
sqls = [call[0][0] for call in mock_logger.call_args_list]
6824+
return any(f"/* {correlation_id} */" in sql for sql in sqls)
6825+
6826+
ctx = Context(paths=[tmp_path], config=Config())
6827+
6828+
# Case: Ensure that the correlation id (plan_id) is included in the SQL for each plan
6829+
for i in range(2):
6830+
create_temp_file(
6831+
tmp_path,
6832+
Path("models", "test.sql"),
6833+
f"MODEL (name test.a, kind FULL); SELECT {i} AS col",
6834+
)
6835+
6836+
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
6837+
ctx.load()
6838+
plan = ctx.plan(auto_apply=True, no_prompts=True)
6839+
6840+
correlation_id = CorrelationId.from_plan_id(plan.plan_id)
6841+
assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"
6842+
6843+
assert _correlation_id_in_sqls(correlation_id, mock_logger)

tests/core/test_table_diff.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,9 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
337337

338338
# make with_settings() return the current instance of engine_adapter so we can still spy on _execute
339339
mocker.patch.object(
340-
engine_adapter, "with_settings", new_callable=lambda: lambda _: engine_adapter
340+
engine_adapter, "with_settings", new_callable=lambda: lambda **kwargs: engine_adapter
341341
)
342-
assert engine_adapter.with_settings(1) == engine_adapter
342+
assert engine_adapter.with_settings() == engine_adapter
343343

344344
spy_execute = mocker.spy(engine_adapter, "_execute")
345345
mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")

0 commit comments

Comments (0)