Skip to content

Commit 17b3af7

Browse files
committed
Support Databricks query tags from session properties
Signed-off-by: Christian Troelsen <christian.troelsen@tryg.dk>
1 parent 3be5bba commit 17b3af7

3 files changed

Lines changed: 177 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ bigquery = [
5151
# pinned an older SQLGlot which is incompatible with SQLMesh
5252
bigframes = ["bigframes>=1.32.0"]
5353
clickhouse = ["clickhouse-connect"]
54-
databricks = ["databricks-sql-connector[pyarrow]"]
54+
databricks = ["databricks-sql-connector[pyarrow]>=4.2.6"]
5555
dev = [
5656
"agate",
5757
"beautifulsoup4",

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,43 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33+
def _query_tags(
    query_tags: t.Optional[t.Union[exp.Expr, str, int, float, bool]],
) -> t.Optional[t.Dict[str, t.Optional[str]]]:
    """Convert the `query_tags` session property into a plain dictionary of tags.

    Args:
        query_tags: The raw `session_properties.query_tags` value. A falsy value is
            treated as "no tags configured"; anything else must be a map expression
            whose keys and values are arrays.

    Returns:
        None when no tags are configured, otherwise a dictionary from tag name to tag
        value in which a NULL map value is represented as None.

    Raises:
        SQLMeshError: If the value is not a map, its keys / values are not arrays, the
            two arrays differ in length, a key is not a string literal, or a value is
            neither a string literal nor NULL.
    """
    if not query_tags:
        return None

    if not isinstance(query_tags, exp.Map):
        raise SQLMeshError("Invalid value for `session_properties.query_tags`. Must be a map.")

    keys = query_tags.args.get("keys")
    values = query_tags.args.get("values")
    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map with array "
            "keys and array values."
        )

    # zip() stops at the shorter input, so without this check a malformed map would
    # silently lose the unmatched tags instead of being reported.
    if len(keys.expressions) != len(values.expressions):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Map keys and values must "
            "have the same number of elements."
        )

    tags: t.Dict[str, t.Optional[str]] = {}
    for key, value in zip(keys.expressions, values.expressions):
        if not isinstance(key, exp.Literal) or not key.is_string:
            raise SQLMeshError(
                "Invalid key in `session_properties.query_tags`. Keys must be string literals."
            )

        if isinstance(value, exp.Null):
            tags[key.this] = None
        elif isinstance(value, exp.Literal) and value.is_string:
            tags[key.this] = value.this
        else:
            raise SQLMeshError(
                "Invalid value in `session_properties.query_tags`. Values must be string "
                "literals or NULL."
            )

    return tags
68+
69+
3370
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
3471
DIALECT = "databricks"
3572
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
@@ -98,6 +135,12 @@ def _use_spark_session(self) -> bool:
98135
def is_spark_session_connection(self) -> bool:
99136
return isinstance(self.connection, SparkSessionConnection)
100137

138+
@property
def _is_databricks_sql_connector_connection(self) -> bool:
    """Return True only when neither Spark path is in use.

    The result is False when the connection is a Spark session connection, or when the
    connection pool's `use_spark_engine_adapter` attribute is truthy.
    """
    if self.is_spark_session_connection:
        return False
    return not self._connection_pool.get_attribute("use_spark_engine_adapter")
143+
101144
def _set_spark_engine_adapter_if_needed(self) -> None:
102145
self._spark_engine_adapter = None
103146

@@ -181,10 +224,23 @@ def _begin_session(self, properties: SessionProperties) -> t.Any:
181224
"""Begin a new session."""
182225
# Align the different possible connectors to a single catalog
183226
self.set_current_catalog(self.default_catalog) # type: ignore
227+
self._connection_pool.set_attribute("query_tags", _query_tags(properties.get("query_tags")))
184228

185229
def _end_session(self) -> None:
    """Clear the session's query tags and reset `use_spark_engine_adapter` on the pool."""
    for attribute, cleared_value in (
        ("query_tags", None),
        ("use_spark_engine_adapter", False),
    ):
        self._connection_pool.set_attribute(attribute, cleared_value)
187232

233+
def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
    """Execute `sql`, attaching the session's query tags when they apply.

    The tags stored on the connection pool are added as the `query_tags` keyword
    argument only if the caller did not pass `query_tags` explicitly and
    `_is_databricks_sql_connector_connection` is true.

    Args:
        sql: The statement to execute.
        track_rows_processed: Passed through unchanged to the parent implementation.
        **kwargs: Extra keyword arguments forwarded to the parent implementation.
    """
    session_tags = self._connection_pool.get_attribute("query_tags")
    if session_tags and "query_tags" not in kwargs:
        # Explicit per-call tags always win; session tags only fill the gap.
        if self._is_databricks_sql_connector_connection:
            kwargs["query_tags"] = session_tags

    return super()._execute(sql, track_rows_processed, **kwargs)
243+
188244
def _df_to_source_queries(
189245
self,
190246
df: DF,

tests/core/engine_adapter/test_databricks.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@
1010
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1111
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
1212
from sqlmesh.core.node import IntervalUnit
13+
from sqlmesh.utils.errors import SQLMeshError
1314
from tests.core.engine_adapter import to_sql_calls
1415

1516
pytestmark = [pytest.mark.databricks, pytest.mark.engine]
1617

1718

19+
def _query_tags_map(*items: t.Optional[str]) -> exp.Map:
    """Build a map expression from alternating ``key, value`` arguments.

    Items at even positions become string-literal keys; items at odd positions become
    string-literal values, with None turned into a NULL expression.
    """
    key_nodes = [exp.Literal.string(name) for name in items[::2]]

    value_nodes = []
    for raw_value in items[1::2]:
        if raw_value is None:
            value_nodes.append(exp.Null())
        else:
            value_nodes.append(exp.Literal.string(raw_value))

    return exp.Map(
        keys=exp.Array(expressions=key_nodes),
        values=exp.Array(expressions=value_nodes),
    )
28+
29+
1830
def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
1931
mocker.patch(
2032
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists",
@@ -117,6 +129,114 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
117129
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]
118130

119131

132+
def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
    """Session query tags are passed to the cursor during the session and not afterwards."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    session_properties = {"query_tags": _query_tags_map("team", "data-eng", "app", "sqlmesh")}
    expected_tags = {"team": "data-eng", "app": "sqlmesh"}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)

    # After the session has ended the tags must no longer be attached.
    adapter.execute("SELECT 2")

    adapter.cursor.execute.assert_called_with("SELECT 2")
148+
149+
150+
def test_session_query_tags_allow_none_values(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A NULL tag value in the session map reaches the cursor as None."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    session_properties = {"query_tags": _query_tags_map("team", "data-eng", "feature", None)}
    expected_tags = {"team": "data-eng", "feature": None}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)
164+
165+
166+
def test_session_query_tags_do_not_override_explicit_query_tags(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Tags passed directly to `execute` take precedence over the session's tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1", query_tags={"team": "analytics"})

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags={"team": "analytics"})
178+
179+
180+
def test_session_query_tags_not_applied_to_spark_session_connection(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """No `query_tags` argument is sent when the adapter reports a Spark session connection."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    # Force the adapter to report a Spark session connection.
    mocker.patch.object(
        DatabricksEngineAdapter,
        "is_spark_session_connection",
        new_callable=mocker.PropertyMock,
        return_value=True,
    )
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1")
198+
199+
200+
def test_session_query_tags_not_applied_to_spark_engine_adapter(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """No `query_tags` argument is sent when `use_spark_engine_adapter` is set on the pool."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    spark_cursor = mocker.Mock()
    adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor)
    pool = adapter._connection_pool
    pool.set_attribute("use_spark_engine_adapter", True)
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with adapter.session(session_properties):
        # The flag is set again inside the session, as in the original test.
        adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)
        adapter.execute("SELECT 1")

    spark_cursor.execute.assert_called_with("SELECT 1")
216+
217+
218+
@pytest.mark.parametrize(
    "query_tags",
    [
        # Not a map at all.
        "team:data-eng",
        # Map whose key is a number rather than a string literal.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.number(1)]),
            values=exp.Array(expressions=[exp.Literal.string("data-eng")]),
        ),
        # Map whose value is a number rather than a string literal or NULL.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.string("team")]),
            values=exp.Array(expressions=[exp.Literal.number(1)]),
        ),
    ],
)
def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable):
    """Opening a session with a malformed `query_tags` property raises SQLMeshError."""
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    session_properties = {"query_tags": query_tags}

    with pytest.raises(SQLMeshError, match="session_properties.query_tags"), adapter.session(
        session_properties
    ):
        pass
238+
239+
120240
def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
121241
mocker.patch(
122242
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"

0 commit comments

Comments
 (0)