Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 167 additions & 1 deletion gel/_internal/_codegen/_models/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,14 @@ def process(self, mod: IntrospectedModule) -> None:
if op.schemapath.parent == self.canonical_modpath
]
)
self.write_non_magic_ternary_operators(
[
op
for op in self._operators.other_ops
if op.schemapath.parent == self.canonical_modpath
if op.operator_kind == reflection.OperatorKind.Ternary
]
)
self.write_globals(mod["globals"])

def reexport_module(self, mod: GeneratedSchemaModule) -> None:
Expand Down Expand Up @@ -3439,6 +3447,71 @@ def param_getter(
excluded_param_types=excluded_param_types,
)

def write_non_magic_ternary_operators(
self,
ops: list[reflection.Operator],
) -> bool:
# Filter to unary operators without Python magic method equivalents
ternary_ops = [op for op in ops if op.py_magic is None]
if not ternary_ops:
return False
else:
self._write_callables(
ternary_ops,
style="function",
type_ignore=("override", "unused-ignore"),
node_ctor=self._write_ternary_op_func_node_ctor,
)
return True

def _write_ternary_op_func_node_ctor(
self,
op: reflection.Operator,
) -> None:
"""Generate the query node constructor for a ternary operator function.

Creates the code that builds a TernaryOp query node for ternary
operator functions. Unlike method versions, this takes the operand from
function arguments and applies special type casting for tuple
parameters.

Args:
op: The operator reflection object containing metadata
"""
node_cls = self.import_name(BASE_IMPL, "TernaryOp")
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
cast_ = self.import_name("typing", "cast")

op_1: str
op_2: str
if op.schemapath == SchemaPath("std", "IF"):
op_1 = op.schemapath.name
op_2 = '"ELSE"'

else:
raise NotImplementedError(f"Unknown operator {op.schemapath}")

if_true = "__args__[0]"
condition = "__args__[1]"
if_false = "__args__[2]"
# Tuple parameters need ExprCompatible casting
# due to a possible mypy bug.
if reflection.is_tuple_type(op.params[0].get_type(self._types)):
if_true = f"{cast_}({expr_compat!r}, {if_true})"
if reflection.is_tuple_type(op.params[2].get_type(self._types)):
if_false = f"{cast_}({expr_compat!r}, {if_false})"

args = [
f"lexpr={if_true}",
f'op_1="{op_1}"', # Gel operator name (e.g., "IF")
f"mexpr={condition}",
f"op_2={op_2}",
f"rexpr={if_false}",
"type_=__rtype__.__gel_reflection__.type_name", # Result type info
]

self.write(self.format_list(f"{node_cls}({{list}}),", args))

def _partition_nominal_overloads(
self,
callables: Iterable[_Callable_T],
Expand Down Expand Up @@ -3649,16 +3722,67 @@ def _write_potentially_overlapping_overloads(
# SEE ABOVE: This is what we actually want.
# key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501
)
base_generic_overload: dict[_Callable_T, _Callable_T] = {}

for overload in overloads:
overload_signatures[overload] = {}

if overload.schemapath == SchemaPath('std', 'IF'):
# HACK: Pretend the base overload of std::IF is generic on
# anyobject.
#
# The base overload of std::IF is
# (anytype, std::bool, anytype) -> anytype
#
# However, this causes an overlap with overloading for bool
# arguments since
# (anytype, builtin.bool, anytype) -> anytype
# overlaps with
# (std::bool, builtin.bool, std::bool) -> std::bool
#
# We resolve this by generating the specializations for anytype
# but using anyobject as the base generic type.

def anytype_to_anyobject(
refl_type: reflection.Type,
default: reflection.Type | reflection.TypeRef,
) -> reflection.Type | reflection.TypeRef:
if isinstance(refl_type, reflection.PseudoType):
return self._types_by_name["anyobject"]
return default

base_generic_overload[overload] = dataclasses.replace(
overload,
params=[
dataclasses.replace(
param,
type=anytype_to_anyobject(
param.get_type(self._types), param.type
),
)
for param in overload.params
],
return_type=anytype_to_anyobject(
overload.get_return_type(self._types),
overload.return_type,
),
)

for param in param_getter(overload):
param_overload_map[param.key].add(overload)
param_type = param.get_type(self._types)
# Unwrap the variadic type (it is reflected as an array of T)
if param.kind is reflection.CallableParamKind.Variadic:
if reflection.is_array_type(param_type):
param_type = param_type.get_element_type(self._types)

if (
overload.schemapath == SchemaPath('std', 'IF')
and param_type.is_pseudo
):
# Also generate the base signature using anyobject
param_type = self._types_by_name["anyobject"]

# Start with the base parameter type
overload_signatures[overload][param.key] = [param_type]

Expand Down Expand Up @@ -3770,7 +3894,10 @@ def specialization_sort_key(t: reflection.Type) -> int:
for overload in overloads:
if overload_specs := overloads_specializations.get(overload):
expanded_overloads.extend(overload_specs)
expanded_overloads.append(overload)
if overload in base_generic_overload:
expanded_overloads.append(base_generic_overload[overload])
else:
expanded_overloads.append(overload)
overloads = expanded_overloads

overload_order = {overload: i for i, overload in enumerate(overloads)}
Expand Down Expand Up @@ -6170,6 +6297,45 @@ def resolve(
f"# type: ignore [assignment, misc, unused-ignore]"
)

if function.schemapath in {
SchemaPath('std', 'UNION'),
SchemaPath('std', 'IF'),
SchemaPath('std', '??'),
}:
# Special case for the UNION, IF and ?? operators
# Produce a union type instead of just taking the first
# valid type.
#
# See gel: edb.compiler.func.compile_operator
create_union = self.import_name(
BASE_IMPL, "create_optional_union"
)

tvars: list[str] = []
for param, path in sources:
if (
param.name in required_generic_params
or param.name in optional_generic_params
):
pn = param_vars[param.name]
tvar = f"__t_{pn}__"

resolve(pn, path, tvar)
tvars.append(tvar)

self.write(
f"{gtvar} = {tvars[0]} "
f"# type: ignore [assignment, misc, unused-ignore]"
)
for tvar in tvars[1:]:
self.write(
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
f"# type: ignore ["
f"assignment, misc, unused-ignore]"
)

continue

# Try to infer generic type from required params first
for param, path in sources:
if param.name in required_generic_params:
Expand Down
2 changes: 2 additions & 0 deletions gel/_internal/_qb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ShapeOp,
Splat,
StringLiteral,
TernaryOp,
UnaryOp,
UpdateStmt,
Variable,
Expand Down Expand Up @@ -171,6 +172,7 @@
"Splat",
"Stmt",
"StringLiteral",
"TernaryOp",
"UnaryOp",
"UpdateStmt",
"VarAlias",
Expand Down
51 changes: 51 additions & 0 deletions gel/_internal/_qb/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,57 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
return f"{left}[{right}]"


@dataclass(kw_only=True, frozen=True)
class TernaryOp(TypedExpr):
lexpr: Expr
op_1: _edgeql.Token
mexpr: Expr
op_2: _edgeql.Token
rexpr: Expr

def __init__(
self,
*,
lexpr: ExprCompatible,
op_1: _edgeql.Token | str,
mexpr: ExprCompatible,
op_2: _edgeql.Token | str,
rexpr: ExprCompatible,
type_: TypeNameExpr,
) -> None:
object.__setattr__(self, "lexpr", edgeql_qb_expr(lexpr))
if not isinstance(op_1, _edgeql.Token):
op_1 = _edgeql.Token.from_str(op_1)
object.__setattr__(self, "op_1", op_1)
object.__setattr__(self, "mexpr", edgeql_qb_expr(mexpr))
if not isinstance(op_2, _edgeql.Token):
op_2 = _edgeql.Token.from_str(op_2)
object.__setattr__(self, "op_2", op_2)
object.__setattr__(self, "rexpr", edgeql_qb_expr(rexpr))
super().__init__(type_=type_)

def subnodes(self) -> Iterable[Node]:
return (self.lexpr, self.mexpr, self.rexpr)

def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
left = edgeql(self.lexpr, ctx=ctx)
if self.lexpr.precedence <= self.precedence:
left = f"({left})"
middle = edgeql(self.mexpr, ctx=ctx)
if self.mexpr.precedence <= self.precedence:
middle = f"({middle})"
right = edgeql(self.rexpr, ctx=ctx)
if self.mexpr.precedence <= self.precedence:
right = f"({right})"
return f"{left} {self.op_1} {middle} {self.op_2} {right}"

@property
def precedence(self) -> _edgeql.Precedence:
return max(
_edgeql.PRECEDENCE[self.op_1], _edgeql.PRECEDENCE[self.op_2]
)


@dataclass(kw_only=True, frozen=True)
class FuncCall(TypedExpr):
fname: str
Expand Down
6 changes: 6 additions & 0 deletions gel/_internal/_qbmodel/_abstract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
from ._methods import (
BaseGelModel,
BaseGelModelIntersection,
BaseGelModelUnion,
create_optional_union,
create_union,
)


Expand Down Expand Up @@ -138,6 +141,7 @@
"ArrayMeta",
"BaseGelModel",
"BaseGelModelIntersection",
"BaseGelModelUnion",
"ComputedLinkSet",
"ComputedLinkWithPropsSet",
"ComputedMultiLinkDescriptor",
Expand Down Expand Up @@ -181,6 +185,8 @@
"TupleMeta",
"UUIDImpl",
"copy_or_ref_lprops",
"create_optional_union",
"create_union",
"empty_set_if_none",
"field_descriptor",
"get_base_scalars_backed_by_py_type",
Expand Down
Loading