Skip to content

Commit 9d2ecea

Browse files
Fix(databricks): Get correct datatypes from information_schema table (#5783)
Signed-off-by: Bjarke Enkelund <47357343+MisterWheatley@users.noreply.github.com>
1 parent 4dac2b3 commit 9d2ecea

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,26 @@ def _build_column_defs(
411411
return super()._build_column_defs(
412412
target_columns_to_types, column_descriptions, is_view, materialized
413413
)
414+
415+
def columns(
416+
self, table_name: TableName, include_pseudo_columns: bool = False
417+
) -> t.Dict[str, exp.DataType]:
418+
table = exp.to_table(table_name)
419+
420+
column_catalog = table.catalog or self.get_current_catalog()
421+
query = (
422+
exp.select("columns.column_name", "columns.full_data_type")
423+
.from_("system.information_schema.columns")
424+
.where(
425+
exp.and_(
426+
exp.column("table_name").eq(table.name),
427+
exp.column("table_schema").eq(table.db),
428+
exp.column("table_catalog").eq(column_catalog),
429+
)
430+
)
431+
.order_by("ordinal_position ASC")
432+
)
433+
434+
result = self.cursor.fetchall(query)
435+
436+
return {row[0]: exp.DataType.build(row[1], dialect=self.dialect) for row in result}

tests/core/engine_adapter/test_databricks.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,61 @@ def test_drop_data_object_materialized_view_calls_correct_drop(mocker: MockFixtu
526526
drop_view_mock.assert_called_once_with(
527527
mv_data_object.to_table(), ignore_if_not_exists=True, materialized=True
528528
)
529+
530+
531+
def test_columns(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
532+
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
533+
534+
# Override/mock get_current_catalog to return default
535+
current_catalog_mock = mocker.patch.object(
536+
adapter, "get_current_catalog", return_value="test_catalog"
537+
)
538+
# create long struct columns datatype
539+
long_struct_cols = [f"a_{i}:int" for i in range(50)]
540+
adapter.cursor.fetchall.return_value = [
541+
("bigint_col", "bigint"),
542+
("binary_col", "binary"),
543+
("boolean_col", "boolean"),
544+
("date_col", "date"),
545+
("decimal_col", "decimal(38,4)"),
546+
("double_col", "double"),
547+
("float_col", "float"),
548+
("int_col", "int"),
549+
("small_int", "smallint"),
550+
("string_col", "string"),
551+
("timestamp_col", "timestamp"),
552+
("timestamp_ntz_col", "timestamp_ntz"),
553+
("tinyint_col", "tinyint"),
554+
("array_col", "array<int>"),
555+
("simple_struct_col", "struct<a:int,b:string>"),
556+
("long_struct_col", f"struct<{','.join(long_struct_cols)}>"),
557+
]
558+
559+
resp = adapter.columns("test_db.test_table")
560+
assert resp == {
561+
"bigint_col": exp.DataType.build("bigint", dialect=adapter.dialect),
562+
"binary_col": exp.DataType.build("binary", dialect=adapter.dialect),
563+
"boolean_col": exp.DataType.build("boolean", dialect=adapter.dialect),
564+
"date_col": exp.DataType.build("date", dialect=adapter.dialect),
565+
"decimal_col": exp.DataType.build("decimal(38,4)", dialect=adapter.dialect),
566+
"double_col": exp.DataType.build("double", dialect=adapter.dialect),
567+
"float_col": exp.DataType.build("float", dialect=adapter.dialect),
568+
"int_col": exp.DataType.build("int", dialect=adapter.dialect),
569+
"small_int": exp.DataType.build("smallint", dialect=adapter.dialect),
570+
"string_col": exp.DataType.build("string", dialect=adapter.dialect),
571+
"timestamp_col": exp.DataType.build("timestamp", dialect=adapter.dialect),
572+
"timestamp_ntz_col": exp.DataType.build("timestamp_ntz", dialect=adapter.dialect),
573+
"tinyint_col": exp.DataType.build("tinyint", dialect=adapter.dialect),
574+
"array_col": exp.DataType.build("array<int>", dialect=adapter.dialect),
575+
"simple_struct_col": exp.DataType.build("struct<a:int,b:string>", dialect=adapter.dialect),
576+
"long_struct_col": exp.DataType.build(
577+
f"struct<{','.join(long_struct_cols)}>", dialect=adapter.dialect
578+
),
579+
}
580+
581+
adapter.cursor.fetchall.assert_called_once_with(
582+
parse_one(
583+
"""SELECT columns.column_name, columns.full_data_type FROM system.information_schema.columns WHERE table_name = 'test_table' AND table_schema = 'test_db' AND table_catalog = 'test_catalog' ORDER BY ordinal_position ASC""",
584+
dialect="databricks",
585+
)
586+
)

0 commit comments

Comments
 (0)