Skip to content

Commit 2e55877

Browse files
authored
expose a source option for trino (#5672)
1 parent d5ceeb2 commit 2e55877

File tree

4 files changed

+41
-1
lines changed

4 files changed

+41
-1
lines changed

docs/integrations/engines/trino.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/
9090
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
9191
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
9292
| `roles` | Mapping of catalog name to a role | dict | N |
93+
| `source` | Value to send as Trino's `source` field for query attribution / auditing. Default: `sqlmesh`. | string | N |
9394
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
9495
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
9596
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |

sqlmesh/core/config/connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,7 @@ class TrinoConnectionConfig(ConnectionConfig):
18881888
client_certificate: t.Optional[str] = None
18891889
client_private_key: t.Optional[str] = None
18901890
cert: t.Optional[str] = None
1891+
source: str = "sqlmesh"
18911892

18921893
# SQLMesh options
18931894
schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
@@ -1984,6 +1985,7 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
19841985
"port",
19851986
"catalog",
19861987
"roles",
1988+
"source",
19871989
"http_scheme",
19881990
"http_headers",
19891991
"session_properties",
@@ -2041,7 +2043,7 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
20412043
"user": self.impersonation_user or self.user,
20422044
"max_attempts": self.retries,
20432045
"verify": self.cert if self.cert is not None else self.verify,
2044-
"source": "sqlmesh",
2046+
"source": self.source,
20452047
}
20462048

20472049
@property

tests/core/engine_adapter/test_trino.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,18 +412,22 @@ def test_timestamp_mapping():
412412
catalog="catalog",
413413
)
414414

415+
assert config._connection_factory_with_kwargs.keywords["source"] == "sqlmesh"
416+
415417
adapter = config.create_engine_adapter()
416418
assert adapter.timestamp_mapping is None
417419

418420
config = TrinoConnectionConfig(
419421
user="user",
420422
host="host",
421423
catalog="catalog",
424+
source="my_source",
422425
timestamp_mapping={
423426
"TIMESTAMP": "TIMESTAMP(6)",
424427
"TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE",
425428
},
426429
)
430+
assert config._connection_factory_with_kwargs.keywords["source"] == "my_source"
427431
adapter = config.create_engine_adapter()
428432
assert adapter.timestamp_mapping is not None
429433
assert adapter.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build(

tests/core/test_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,39 @@ def test_trino_schema_location_mapping_syntax(tmp_path):
862862
assert len(conn.schema_location_mapping) == 2
863863

864864

865+
def test_trino_source_option(tmp_path):
866+
config_path = tmp_path / "config_trino_source.yaml"
867+
with open(config_path, "w", encoding="utf-8") as fd:
868+
fd.write(
869+
"""
870+
gateways:
871+
trino:
872+
connection:
873+
type: trino
874+
user: trino
875+
host: trino
876+
catalog: trino
877+
source: my_sqlmesh_source
878+
879+
default_gateway: trino
880+
881+
model_defaults:
882+
dialect: trino
883+
"""
884+
)
885+
886+
config = load_config_from_paths(
887+
Config,
888+
project_paths=[config_path],
889+
)
890+
891+
from sqlmesh.core.config.connection import TrinoConnectionConfig
892+
893+
conn = config.gateways["trino"].connection
894+
assert isinstance(conn, TrinoConnectionConfig)
895+
assert conn.source == "my_sqlmesh_source"
896+
897+
865898
def test_gcp_postgres_ip_and_scopes(tmp_path):
866899
config_path = tmp_path / "config_gcp_postgres.yaml"
867900
with open(config_path, "w", encoding="utf-8") as fd:

0 commit comments

Comments
 (0)