Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,7 +2229,7 @@ def _builder(dtype: exp.DataType) -> exp.DataType:
return dtype

params = f"{precision}{f', {scale}' if scale is not None else ''}"
return exp.DataType.build(f"DECIMAL({params})")
return exp.DataType.from_str(f"DECIMAL({params})")

return _builder

Expand Down
61 changes: 40 additions & 21 deletions sqlglot/expressions/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from typing_extensions import Self


class DataTypeParam(Expression):
Expand Down Expand Up @@ -330,7 +331,7 @@ def build(
udt: bool = False,
copy: bool = True,
**kwargs: object,
) -> DataType:
) -> Self:
"""
Constructs a DataType object.

Expand All @@ -345,32 +346,50 @@ def build(
Returns:
The constructed DataType object.
"""
from sqlglot import parse_one

if isinstance(dtype, str):
if dtype.upper() == "UNKNOWN":
return DataType(this=DType.UNKNOWN, **kwargs)

try:
data_type_exp = parse_one(
dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE
)
except ParseError:
if udt:
return DataType(this=DType.USERDEFINED, kind=dtype, **kwargs)
raise
return cls.from_str(dtype, dialect, udt, **kwargs)
elif isinstance(dtype, DType):
data_type_exp = DataType(this=dtype)
data_type_exp = cls(this=dtype)
if kwargs:
for k, v in kwargs.items():
data_type_exp.set(k, v)
return data_type_exp
elif isinstance(dtype, (Identifier, Dot)) and udt:
return DataType(this=DType.USERDEFINED, kind=dtype, **kwargs)
elif isinstance(dtype, DataType):
return cls(this=DType.USERDEFINED, kind=dtype, **kwargs)
elif isinstance(dtype, cls):
return maybe_copy(dtype, copy)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DType")
if kwargs:
for k, v in kwargs.items():
data_type_exp.set(k, v)
return data_type_exp

@classmethod
def from_str(
    cls, dtype: str, dialect: DialectType = None, udt: bool = False, **kwargs: object
) -> Self:
    """
    Constructs a `DataType` object from a `str` representation.

    Args:
        dtype: the data type of interest.
        dialect: the dialect to use for parsing `dtype`.
        udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
            `DataType`, thus creating a user-defined type.
        kwargs: additional arguments to pass in the constructor of `DataType`.

    Returns:
        The constructed `DataType` object.
    """
    # Imported locally to avoid a circular import at module load time.
    from sqlglot import parse_one

    # "UNKNOWN" is a sentinel type that never goes through the parser.
    if dtype.upper() == "UNKNOWN":
        return cls(this=DType.UNKNOWN, **kwargs)

    try:
        return parse_one(
            dtype, read=dialect, into=cls, error_level=ErrorLevel.IGNORE
        ).set_kwargs(kwargs)
    except ParseError:
        if not udt:
            raise
        # Couldn't parse the string, but udt=True: treat it as a user-defined type.
        return cls(this=DType.USERDEFINED, kind=dtype, **kwargs)

def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool:
"""
Expand Down
20 changes: 10 additions & 10 deletions sqlglot/generators/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _to_boolean_sql(self: DuckDBGenerator, expression: exp.ToBoolean) -> str:
case_expr = base_case_expr.else_(exp.func("TRY_CAST", arg, exp.DType.BOOLEAN.into_expr()))
else:
# TO_BOOLEAN: handle NaN/INF errors, 'on'/'off', and use regular CAST
cast_to_real = exp.func("TRY_CAST", arg, exp.DataType.build(exp.DType.FLOAT))
cast_to_real = exp.func("TRY_CAST", arg, exp.DType.FLOAT.into_expr())

# Check for NaN and INF values
nan_inf_check = exp.Or(
Expand Down Expand Up @@ -891,7 +891,7 @@ def _generate_datetime_array_sql(
if is_generate_date_array:
# The GENERATE_SERIES result type is TIMESTAMP array, so to match BQ's semantics for
# GENERATE_DATE_ARRAY we must cast it back to DATE array
gen_series = exp.cast(gen_series, exp.DataType.build("ARRAY<DATE>"))
gen_series = exp.cast(gen_series, exp.DataType.from_str("ARRAY<DATE>"))

return self.sql(gen_series)

Expand All @@ -901,7 +901,7 @@ def _json_extract_value_array_sql(
) -> str:
json_extract = exp.JSONExtract(this=expression.this, expression=expression.expression)
data_type = "ARRAY<STRING>" if isinstance(expression, exp.JSONValueArray) else "ARRAY<JSON>"
return self.sql(exp.cast(json_extract, to=exp.DataType.build(data_type)))
return self.sql(exp.cast(json_extract, to=exp.DataType.from_str(data_type)))


def _cast_to_varchar(arg: exp.Expr | None) -> exp.Expr | None:
Expand All @@ -926,7 +926,7 @@ def _is_binary(arg: exp.Expr) -> bool:

def _gen_with_cast_to_blob(self: DuckDBGenerator, expression: exp.Expr, result_sql: str) -> str:
if _is_binary(expression):
blob = exp.DataType.build("BLOB", dialect="duckdb")
blob = exp.DataType.from_str("BLOB", dialect="duckdb")
result_sql = self.sql(exp.Cast(this=result_sql, to=blob))
return result_sql

Expand Down Expand Up @@ -1199,7 +1199,7 @@ def _bitshift_sql(

if result_is_blob:
result_sql = self.sql(
exp.Cast(this=result_sql, to=exp.DataType.build("BLOB", dialect="duckdb"))
exp.Cast(this=result_sql, to=exp.DataType.from_str("BLOB", dialect="duckdb"))
)

return result_sql
Expand Down Expand Up @@ -1663,7 +1663,7 @@ class DuckDBGenerator(generator.Generator):
exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
exp.ToBoolean: _to_boolean_sql,
exp.ToVariant: lambda self, e: self.sql(
exp.cast(e.this, exp.DataType.build("VARIANT", dialect="duckdb"))
exp.cast(e.this, exp.DataType.from_str("VARIANT", dialect="duckdb"))
),
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: (
Expand Down Expand Up @@ -2536,7 +2536,7 @@ def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
this = expression.this
time_format = self.format_time(expression)
safe = expression.args.get("safe")
time_type = exp.DataType.build("TIME", dialect="duckdb")
time_type = exp.DataType.from_str("TIME", dialect="duckdb")
cast_expr = exp.TryCast if safe else exp.Cast

if time_format:
Expand Down Expand Up @@ -2760,7 +2760,7 @@ def extract_sql(self, expression: exp.Extract) -> str:
this=exp.Extract(this=exp.var("MICROSECOND"), expression=datetime_expr),
expression=exp.Literal.number(1000),
),
exp.DataType.build(cast_type, dialect="duckdb"),
exp.DataType.from_str(cast_type, dialect="duckdb"),
)
)

Expand All @@ -2775,7 +2775,7 @@ def extract_sql(self, expression: exp.Extract) -> str:
this="STRFTIME",
expressions=[strftime_input, exp.Literal.string(fmt)],
),
exp.DataType.build(cast_type, dialect="duckdb"),
exp.DataType.from_str(cast_type, dialect="duckdb"),
)
)

Expand All @@ -2784,7 +2784,7 @@ def extract_sql(self, expression: exp.Extract) -> str:
result: exp.Expr = exp.Anonymous(this=func_name, expressions=[datetime_expr])
# EPOCH returns float, cast to BIGINT for integer result
if part_name == "EPOCH_SECOND":
result = exp.cast(result, exp.DataType.build("BIGINT", dialect="duckdb"))
result = exp.cast(result, exp.DataType.from_str("BIGINT", dialect="duckdb"))
return self.sql(result)

return super().extract_sql(expression)
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/generators/singlestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class SingleStoreGenerator(MySQLGenerator):
lambda self, e: f"{self.sql(e, 'this')} !:> {self.sql(e, 'to')}"
),
exp.CastToStrType: lambda self, e: self.sql(
exp.cast(e.this, DataType.build(e.args["to"].name))
exp.cast(e.this, DataType.from_str(e.args["to"].name))
),
exp.StrToUnix: unsupported_args("format")(rename_func("UNIX_TIMESTAMP")),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _set_type(
# setter to enforce the getter's return type (Optional[DataType]), rejecting DType.
# Bypass by converting and assigning to _type directly.
dtype = target_type or exp.DType.UNKNOWN
expression._type = dtype if isinstance(dtype, exp.DataType) else exp.DataType.build(dtype)
expression._type = dtype if isinstance(dtype, exp.DataType) else dtype.into_expr()
self._visited.add(expression_id)

if (
Expand Down
8 changes: 4 additions & 4 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6050,7 +6050,7 @@ def _parse_user_defined_type(self, identifier: exp.Identifier) -> exp.Expr | Non
while self._match(TokenType.DOT):
type_name = f"{type_name}.{self._advance_any() and self._prev.text}"

return exp.DataType.build(type_name, dialect=self.dialect, udt=True)
return exp.DataType.from_str(type_name, dialect=self.dialect, udt=True)

def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
Expand All @@ -6073,7 +6073,7 @@ def _parse_types(

if tokens and (type_token := tokens[0].token_type) in self.TYPE_TOKENS:
if len(tokens) > 1:
return exp.DataType.build(identifier.name, dialect=self.dialect)
return exp.DataType.from_str(identifier.name, dialect=self.dialect)
elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
this = self._parse_user_defined_type(identifier)
else:
Expand Down Expand Up @@ -6324,7 +6324,7 @@ def _parse_json_type_arg(self) -> exp.Expr | None:
return self.expression(exp.ColumnDef(this=col, kind=kind))

def _parse_vector_expressions(self, expressions: list[exp.Expr]) -> list[exp.Expr]:
return [exp.DataType.build(expressions[0].name, dialect=self.dialect), *expressions[1:]]
return [exp.DataType.from_str(expressions[0].name, dialect=self.dialect), *expressions[1:]]

def _parse_struct_types(self, type_required: bool = False) -> exp.Expr | None:
index = self._index
Expand Down Expand Up @@ -7622,7 +7622,7 @@ def _parse_cast(self, strict: bool, safe: bool | None = None) -> exp.Expr:
elif not to:
self.raise_error("Expected TYPE after CAST")
elif isinstance(to, exp.Identifier):
to = exp.DataType.build(to.name, dialect=self.dialect, udt=True)
to = exp.DataType.from_str(to.name, dialect=self.dialect, udt=True)
elif to.this == exp.DType.CHAR and self._match(TokenType.CHARACTER_SET):
to = exp.DType.CHARACTER_SET.into_expr(kind=self._parse_var_or_string())

Expand Down
4 changes: 2 additions & 2 deletions sqlglot/parsers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def _builder(args: list) -> exp.Func:
# format strings (e.g., TO_TIMESTAMP('20240115', 'YYYYMMDD')) should
# use StrToTime, not UnixToTime.
unix_expr = exp.UnixToTime(this=value, scale=scale_or_fmt)
unix_expr.set("target_type", exp.DataType.build(kind, dialect="snowflake"))
unix_expr.set("target_type", kind.into_expr())
return unix_expr
if scale_or_fmt and not int_scale_or_fmt:
# Format string provided (e.g., 'YYYY-MM-DD'), use StrToTime
strtotime_expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
strtotime_expr.set("safe", safe)
strtotime_expr.set("target_type", exp.DataType.build(kind, dialect="snowflake"))
strtotime_expr.set("target_type", kind.into_expr())
return strtotime_expr

# Handle DATE/TIME with format strings - allow int_value if a format string is provided
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parsers/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def build_as_cast(to_type: str) -> t.Callable[[list], exp.Expr]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.from_str(to_type))


class Spark2Parser(HiveParser):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.Da
udt = dialect.SUPPORTS_USER_DEFINED_TYPES

try:
expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
expression = exp.DataType.from_str(schema_type, dialect=dialect, udt=udt)
expression.transform(dialect.normalize_identifier, copy=False)
self._type_mapping_cache[schema_type] = expression
except AttributeError:
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@
"annotator": lambda self, e: self._annotate_by_args(e, "start", "end", "step", array=True)
},
exp.GenerateDateArray: {
"annotator": lambda self, e: self._set_type(e, exp.DataType.build("ARRAY<DATE>"))
"annotator": lambda self, e: self._set_type(e, exp.DataType.from_str("ARRAY<DATE>"))
},
exp.GenerateTimestampArray: {
"annotator": lambda self, e: self._set_type(e, exp.DataType.build("ARRAY<TIMESTAMP>"))
"annotator": lambda self, e: self._set_type(e, exp.DataType.from_str("ARRAY<TIMESTAMP>"))
},
exp.If: {"annotator": lambda self, e: self._annotate_by_args(e, "true", "false")},
exp.Literal: {"annotator": lambda self, e: self._annotate_literal(e)},
Expand Down
8 changes: 4 additions & 4 deletions sqlglot/typing/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array:
exp.DateFromUnixDate: {"returns": exp.DType.DATE},
exp.GenerateTimestampArray: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("ARRAY<TIMESTAMP>", dialect="bigquery")
e, exp.DataType.from_str("ARRAY<TIMESTAMP>", dialect="bigquery")
)
},
exp.JSONFormat: {
Expand All @@ -338,12 +338,12 @@ def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array:
},
exp.JSONKeysAtDepth: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
e, exp.DataType.from_str("ARRAY<VARCHAR>", dialect="bigquery")
)
},
exp.JSONValueArray: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
e, exp.DataType.from_str("ARRAY<VARCHAR>", dialect="bigquery")
)
},
exp.Lag: {"annotator": lambda self, e: self._annotate_by_args(e, "this", "default")},
Expand All @@ -352,7 +352,7 @@ def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array:
exp.SafeDivide: {"annotator": lambda self, e: _annotate_safe_divide(self, e)},
exp.ToCodePoints: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("ARRAY<BIGINT>", dialect="bigquery")
e, exp.DataType.from_str("ARRAY<BIGINT>", dialect="bigquery")
)
},
}
2 changes: 1 addition & 1 deletion sqlglot/typing/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@
},
exp.If: {"annotator": lambda self, e: self._annotate_by_args(e, "true", "false", promote=True)},
exp.Quantile: {"annotator": lambda self, e: self._annotate_by_args(e, "quantile")},
exp.RegexpSplit: {"returns": exp.DataType.build("ARRAY<STRING>")},
exp.RegexpSplit: {"returns": exp.DataType.from_str("ARRAY<STRING>")},
}
18 changes: 11 additions & 7 deletions sqlglot/typing/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median:
new_scale = min(scale + 3, MAX_SCALE)

# Build the new NUMBER type
new_type = exp.DataType.build(f"NUMBER({new_precision}, {new_scale})", dialect="snowflake")
new_type = exp.DataType.from_str(
f"NUMBER({new_precision}, {new_scale})", dialect="snowflake"
)
self._set_type(expression, new_type)

return expression
Expand All @@ -161,7 +163,7 @@ def _annotate_variance(self: TypeAnnotator, expression: exp.Expr) -> exp.Expr:

# Special case: DECFLOAT -> DECFLOAT(38)
if input_type.is_type(exp.DType.DECFLOAT):
self._set_type(expression, exp.DataType.build("DECFLOAT", dialect="snowflake"))
self._set_type(expression, exp.DataType.from_str("DECFLOAT", dialect="snowflake"))
# Special case: FLOAT/DOUBLE -> DOUBLE
elif input_type.is_type(exp.DType.FLOAT, exp.DType.DOUBLE):
self._set_type(expression, exp.DType.DOUBLE)
Expand All @@ -176,7 +178,9 @@ def _annotate_variance(self: TypeAnnotator, expression: exp.Expr) -> exp.Expr:
new_scale = 6 if scale == 0 else max(12, scale)

# Build the new NUMBER type
new_type = exp.DataType.build(f"NUMBER({MAX_PRECISION}, {new_scale})", dialect="snowflake")
new_type = exp.DataType.from_str(
f"NUMBER({MAX_PRECISION}, {new_scale})", dialect="snowflake"
)
self._set_type(expression, new_type)

return expression
Expand All @@ -194,12 +198,12 @@ def _annotate_kurtosis(self: TypeAnnotator, expression: exp.Kurtosis) -> exp.Kur
input_type = expression.this.type

if input_type.is_type(exp.DType.DECFLOAT):
self._set_type(expression, exp.DataType.build("DECFLOAT", dialect="snowflake"))
self._set_type(expression, exp.DataType.from_str("DECFLOAT", dialect="snowflake"))
elif input_type.is_type(exp.DType.FLOAT, exp.DType.DOUBLE):
self._set_type(expression, exp.DType.DOUBLE)
else:
self._set_type(
expression, exp.DataType.build(f"NUMBER({MAX_PRECISION}, 12)", dialect="snowflake")
expression, exp.DataType.from_str(f"NUMBER({MAX_PRECISION}, 12)", dialect="snowflake")
)

return expression
Expand Down Expand Up @@ -339,7 +343,7 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp
**{
expr_type: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("NUMBER", dialect="snowflake")
e, exp.DataType.from_str("NUMBER", dialect="snowflake")
)
}
for expr_type in (
Expand Down Expand Up @@ -547,7 +551,7 @@ def _annotate_str_to_time(self: TypeAnnotator, expression: exp.StrToTime) -> exp
exp.DecodeCase: {"annotator": _annotate_decode_case},
exp.HashAgg: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.build("NUMBER(19, 0)", dialect="snowflake")
e, exp.DataType.from_str("NUMBER(19, 0)", dialect="snowflake")
)
},
exp.Median: {"annotator": _annotate_median},
Expand Down
Loading