Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions sqlmesh/core/engine_adapter/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as t

from sqlglot import exp
from sqlglot.helper import ensure_list

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
Expand All @@ -30,6 +31,7 @@

from sqlmesh.core._typing import SchemaName, TableName
from sqlmesh.core.engine_adapter.base import QueryOrDF, Query
from sqlmesh.core.node import IntervalUnit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,6 +251,63 @@ def create_view(
**create_kwargs,
)

def _build_table_properties_exp(
    self,
    catalog_name: t.Optional[str] = None,
    table_format: t.Optional[str] = None,
    storage_format: t.Optional[str] = None,
    partitioned_by: t.Optional[t.List[exp.Expr]] = None,
    partition_interval_unit: t.Optional[IntervalUnit] = None,
    clustered_by: t.Optional[t.List[exp.Expr]] = None,
    table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    table_kind: t.Optional[str] = None,
    **kwargs: t.Any,
) -> t.Optional[exp.Properties]:
    """Assemble the properties node for a CREATE TABLE statement.

    Only ``table_description`` and ``table_properties`` are read here; the
    remaining parameters are accepted but not used by this implementation.

    Args:
        table_properties: Mapping of property name to expression. Keys are
            upper-cased before lookup, so they match case-insensitively. The
            keys consulted directly are ``DISTSTYLE``, ``DISTKEY`` and
            ``SORTKEY``; the mapping is also handed to
            ``_pop_creatable_type_from_properties``.
        table_description: Optional table comment; it is passed through
            ``_truncate_table_comment`` before being attached.

    Returns:
        An ``exp.Properties`` wrapping every collected property, or ``None``
        when nothing was collected.
    """

    def _as_identifier(node: exp.Expr) -> exp.Expr:
        # A string literal is turned into an identifier; any other
        # expression is copied so the caller's node is not shared.
        if not (isinstance(node, exp.Literal) and node.is_string):
            return node.copy()
        return exp.to_identifier(node.this)

    collected: t.List[exp.Expr] = []

    if table_description:
        comment = self._truncate_table_comment(table_description)
        collected.append(exp.SchemaCommentProperty(this=exp.Literal.string(comment)))

    if table_properties:
        # Work on an upper-cased copy so lookups below ignore key casing.
        normalized = {key.upper(): value for key, value in table_properties.items()}

        creatable_type = self._pop_creatable_type_from_properties(normalized)
        collected.extend(ensure_list(creatable_type))

        diststyle = normalized.get("DISTSTYLE")
        if diststyle:
            style_name = diststyle.name.upper()
            collected.append(exp.DistStyleProperty(this=exp.var(style_name)))

        distkey = normalized.get("DISTKEY")
        if distkey:
            collected.append(exp.DistKeyProperty(this=_as_identifier(distkey)))

        sortkey = normalized.get("SORTKEY")
        if sortkey:
            # An expression without child expressions is treated as a
            # single sort column.
            sort_sources = sortkey.expressions or [sortkey]
            sort_columns = [_as_identifier(source) for source in sort_sources]
            collected.append(exp.SortKeyProperty(this=sort_columns, compound=False))

    if not collected:
        return None
    return exp.Properties(expressions=collected)

def replace_query(
self,
table_name: TableName,
Expand Down
155 changes: 155 additions & 0 deletions tests/core/engine_adapter/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from sqlglot import expressions as exp
from sqlglot import parse_one

import sqlmesh.core.dialect as d
from sqlmesh.core.engine_adapter import RedshiftEngineAdapter
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.utils.errors import SQLMeshError
from tests.core.engine_adapter import to_sql_calls

Expand All @@ -32,6 +35,158 @@ def test_columns(adapter: t.Callable):
assert resp == {"col": exp.DataType.build("INT")}


def test_create_table_physical_properties(make_mocked_engine_adapter: t.Callable):
    """create_table renders DISTSTYLE, DISTKEY and SORTKEY from column expressions."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    columns = {
        "id_file": exp.DataType.build("INT"),
        "batch_time": exp.DataType.build("TIMESTAMP"),
    }
    physical_properties = {
        "diststyle": exp.column("key"),
        "distkey": exp.to_column("id_file"),
        "sortkey": exp.to_column("batch_time"),
    }

    redshift.create_table(
        "test_schema.test_table",
        columns,
        table_properties=physical_properties,
    )

    expected_sql = 'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")'
    assert to_sql_calls(redshift) == [expected_sql]


@pytest.mark.parametrize(
    ("diststyle", "expected"),
    [
        ("auto", "AUTO"),
        ("even", "EVEN"),
        ("key", "KEY"),
        ("all", "ALL"),
    ],
)
def test_create_table_physical_properties_diststyle_values(
    make_mocked_engine_adapter: t.Callable,
    diststyle: str,
    expected: str,
):
    """Each parametrized diststyle value is rendered upper-cased in the DDL."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    # Only the "key" case also supplies a distkey, and expects it in the DDL.
    uses_distkey = diststyle == "key"

    physical_properties = {"diststyle": exp.column(diststyle)}
    if uses_distkey:
        physical_properties["distkey"] = exp.to_column("id_file")

    redshift.create_table(
        "test_schema.test_table",
        {"id_file": exp.DataType.build("INT")},
        table_properties=physical_properties,
    )

    expected_distkey = ' DISTKEY("id_file")' if uses_distkey else ""
    expected_calls = [
        f'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER) DISTSTYLE {expected}{expected_distkey}',
    ]
    assert to_sql_calls(redshift) == expected_calls


def test_create_table_physical_properties_distkey_without_diststyle(
    make_mocked_engine_adapter: t.Callable,
):
    """A distkey given without a diststyle yields a DISTKEY clause and no DISTSTYLE."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    columns = {"id_file": exp.DataType.build("INT")}
    physical_properties = {"distkey": exp.to_column("id_file")}

    redshift.create_table(
        "test_schema.test_table",
        columns,
        table_properties=physical_properties,
    )

    expected_sql = 'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER) DISTKEY("id_file")'
    assert to_sql_calls(redshift) == [expected_sql]


def test_create_table_physical_properties_multi_column_sortkey(
    make_mocked_engine_adapter: t.Callable,
):
    """A tuple-valued sortkey is rendered as a multi-column SORTKEY clause."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    columns = {
        "id_file": exp.DataType.build("INT"),
        "batch_time": exp.DataType.build("TIMESTAMP"),
        "event_time": exp.DataType.build("TIMESTAMP"),
    }
    sort_columns = [exp.to_column("batch_time"), exp.to_column("event_time")]
    physical_properties = {
        "diststyle": exp.column("key"),
        "distkey": exp.to_column("id_file"),
        "sortkey": exp.Tuple(expressions=sort_columns),
    }

    redshift.create_table(
        "test_schema.test_table",
        columns,
        table_properties=physical_properties,
    )

    expected_sql = 'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP, "event_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time", "event_time")'
    assert to_sql_calls(redshift) == [expected_sql]


def test_create_table_physical_properties_with_string_columns(
    make_mocked_engine_adapter: t.Callable,
):
    """String-literal property values produce the same DDL as column expressions."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    columns = {
        "id_file": exp.DataType.build("INT"),
        "batch_time": exp.DataType.build("TIMESTAMP"),
    }
    physical_properties = {
        "diststyle": exp.Literal.string("key"),
        "distkey": exp.Literal.string("id_file"),
        "sortkey": exp.Literal.string("batch_time"),
    }

    redshift.create_table(
        "test_schema.test_table",
        columns,
        table_properties=physical_properties,
    )

    expected_sql = 'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")'
    assert to_sql_calls(redshift) == [expected_sql]


def test_create_table_physical_properties_from_model_definition(
    make_mocked_engine_adapter: t.Callable,
):
    """physical_properties parsed from a MODEL block reach the generated DDL."""
    redshift = make_mocked_engine_adapter(RedshiftEngineAdapter)

    # The model text is kept verbatim, including its leading and trailing newline.
    parsed_definition = d.parse(
        """
MODEL (
name test_schema.test_table,
kind full,
physical_properties (
diststyle = key,
distkey = "id_file",
sortkey = "batch_time"
)
);
SELECT id_file::INT, batch_time::TIMESTAMP;
"""
    )
    loaded_model = load_sql_based_model(parsed_definition)
    model: SqlModel = t.cast(SqlModel, loaded_model)

    redshift.create_table(
        model.name,
        target_columns_to_types=model.columns_to_types_or_raise,
        table_properties=model.physical_properties,
    )

    expected_sql = 'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")'
    assert to_sql_calls(redshift) == [expected_sql]


def test_varchar_size_workaround(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

Expand Down