2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 Repl.it
Copyright (c) 2024 Replit

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
154 changes: 70 additions & 84 deletions replit_river/codegen/client.py
@@ -4,7 +4,6 @@
from pathlib import Path
from textwrap import dedent, indent
from typing import (
Any,
Dict,
List,
Literal,
@@ -180,15 +179,7 @@ class RiverIntersectionType(BaseModel):
allOf: List["RiverType"]


class RiverNotType(BaseModel):
"""This is used to represent void / never."""

not_: Any = Field(..., alias="not")


RiverType = Union[
RiverConcreteType, RiverUnionType, RiverNotType, RiverIntersectionType
]
RiverType = Union[RiverConcreteType, RiverUnionType, RiverIntersectionType]


class RiverProcedure(BaseModel):
@@ -239,8 +230,6 @@ def encode_type(
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
encoder_name: Optional[str] = None # defining this up here to placate mypy
chunks: List[FileContents] = []
if isinstance(type, RiverNotType):
return (TypeName("None"), [], [], set())
if isinstance(type, RiverUnionType):
typeddict_encoder = list[str]()
encoder_names: set[TypeName] = set()
@@ -352,7 +341,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
else:
local_discriminator = "FIXME: Ambiguous discriminators"
typeddict_encoder.append(
f" if '{local_discriminator}' in x else "
f" if {repr(local_discriminator)} in x else "
)
typeddict_encoder.pop() # Drop the last ternary
typeddict_encoder.append(")")
@@ -372,8 +361,8 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
typeddict_encoder.append(f"{encoder_name}(x)")
typeddict_encoder.append(
f"""
if x['{discriminator_name}']
== '{discriminator_value}'
if x[{repr(discriminator_name)}]
== {repr(discriminator_value)}
else
""",
)
@@ -393,7 +382,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
[
dedent(
f"""\
{encoder_name}: Callable[['{prefix}'], Any] = (
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
lambda x:
""".rstrip()
)
@@ -450,14 +439,17 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
chunks.append(
FileContents(
"\n".join(
[f"{encoder_name}: Callable[['{prefix}'], Any] = (lambda x: "]
[
f"{encoder_name}: Callable[[{repr(prefix)}], Any] = ("
"lambda x: "
]
+ typeddict_encoder
+ [")"]
)
)
)
return (prefix, in_module, chunks, encoder_names)
if isinstance(type, RiverIntersectionType):
elif isinstance(type, RiverIntersectionType):

def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
if isinstance(tpe, RiverUnionType):
@@ -478,15 +470,17 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
base_model,
in_module,
)
if isinstance(type, RiverConcreteType):
elif isinstance(type, RiverConcreteType):
typeddict_encoder = list[str]()
if type.type is None:
# Handle the case where type is not specified
typeddict_encoder.append("x")
return (TypeName("Any"), [], [], set())
elif type.type == "not":
return (TypeName("None"), [], [], set())
elif type.type == "string":
if type.const:
typeddict_encoder.append(f"'{type.const}'")
typeddict_encoder.append(repr(type.const))
return (LiteralTypeExpr(type.const), [], [], set())
else:
typeddict_encoder.append("x")
@@ -565,48 +559,48 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
name,
prop,
) in sorted(list(type.properties.items()), key=lambda xs: xs[0]):
typeddict_encoder.append(f"'{name}':")
typeddict_encoder.append(f"{repr(name)}:")
type_name, _, contents, _ = encode_type(
prop, TypeName(prefix + name.title()), base_model, in_module
)
encoder_name = None
chunks.extend(contents)
if base_model == "TypedDict":
if isinstance(prop, RiverNotType):
typeddict_encoder.append("'not implemented'")
elif isinstance(prop, RiverUnionType):
if isinstance(prop, RiverUnionType):
encoder_name = TypeName(
f"encode_{ensure_literal_type(type_name)}"
)
encoder_names.add(encoder_name)
typeddict_encoder.append(f"{encoder_name}(x['{name}'])")
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
if name not in type.required:
typeddict_encoder.append(f"if x['{name}'] else None")
typeddict_encoder.append(f"if x[{repr(name)}] else None")
elif isinstance(prop, RiverIntersectionType):
encoder_name = TypeName(
f"encode_{ensure_literal_type(type_name)}"
)
encoder_names.add(encoder_name)
typeddict_encoder.append(f"{encoder_name}(x['{name}'])")
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
elif isinstance(prop, RiverConcreteType):
if name == "$kind":
safe_name = "kind"
else:
safe_name = name
if prop.type == "object" and not prop.patternProperties:
if prop.type == "not":
typeddict_encoder.append("'not implemented'")
elif prop.type == "object" and not prop.patternProperties:
encoder_name = TypeName(
f"encode_{ensure_literal_type(type_name)}"
)
encoder_names.add(encoder_name)
typeddict_encoder.append(
f"{encoder_name}(x['{safe_name}'])"
f"{encoder_name}(x[{repr(safe_name)}])"
)
if name not in prop.required:
typeddict_encoder.append(
dedent(
f"""
if '{safe_name}' in x
and x['{safe_name}'] is not None
if {repr(safe_name)} in x
and x[{repr(safe_name)}] is not None
else None
"""
)
@@ -615,7 +609,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
items = cast(RiverConcreteType, prop).items
assert items, "Somehow items was none"
if is_literal(cast(RiverType, items)):
typeddict_encoder.append(f"x['{name}']")
typeddict_encoder.append(f"x[{repr(name)}]")
else:
match type_name:
case ListTypeExpr(inner_type_name):
@@ -628,16 +622,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
f"""\
[
{encoder_name}(y)
for y in x['{name}']
for y in x[{repr(name)}]
]
""".rstrip()
)
)
else:
if name in prop.required:
typeddict_encoder.append(f"x['{safe_name}']")
typeddict_encoder.append(f"x[{repr(safe_name)}]")
else:
typeddict_encoder.append(f"x.get('{safe_name}')")
typeddict_encoder.append(f"x.get({repr(safe_name)})")

if name == "$kind":
# If the field is a literal, the Python type-checker will complain
@@ -657,7 +651,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
f"""\
= Field(
default=None,
alias='{name}', # type: ignore
alias={repr(name)}, # type: ignore
)
"""
)
@@ -671,7 +665,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
f"""\
= Field(
{field_value},
alias='{name}', # type: ignore
alias={repr(name)}, # type: ignore
)
"""
)
@@ -714,7 +708,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
[
dedent(
f"""\
{encoder_name}: Callable[['{prefix}'], Any] = (
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
lambda {binding}:
"""
)
@@ -781,7 +775,7 @@ def __init__(self, client: river.Client[Any]):
module_names = [ModuleName(name)]
init_type: Optional[TypeExpression] = None
if procedure.init:
init_type, module_info, input_chunks, encoder_names = encode_type(
init_type, module_info, init_chunks, encoder_names = encode_type(
procedure.init,
TypeName(f"{name.title()}Init"),
input_base_class,
@@ -791,7 +785,7 @@ def __init__(self, client: river.Client[Any]):
(
[extract_inner_type(init_type), *encoder_names],
module_info,
input_chunks,
init_chunks,
)
)
input_type, module_info, input_chunks, encoder_names = encode_type(
@@ -856,31 +850,28 @@ def __init__(self, client: river.Client[Any]):

# Init renderer
render_init_method: Optional[str] = None
if input_base_class == "TypedDict" and init_type:
if is_literal(procedure.input):
render_init_method = "lambda x: x"
elif isinstance(
procedure.input, RiverConcreteType
) and procedure.input.type in ["array"]:
match init_type:
case ListTypeExpr(init_type_name):
render_init_method = (
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
)
if init_type and procedure.init is not None:
if input_base_class == "TypedDict":
if is_literal(procedure.init):
render_init_method = "lambda x: x"
elif isinstance(
procedure.init, RiverConcreteType
) and procedure.init.type in ["array"]:
match init_type:
case ListTypeExpr(init_type_name):
render_init_method = (
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
)
else:
render_init_method = f"encode_{ensure_literal_type(init_type)}"
else:
render_init_method = f"encode_{ensure_literal_type(init_type)}"
else:
render_init_method = f"""\
lambda x: TypeAdapter({render_type_expr(input_type)})
.validate_python
"""
if isinstance(
procedure.init, RiverConcreteType
) and procedure.init.type not in ["object", "array"]:
render_init_method = "lambda x: x"
render_init_method = f"""\
lambda x: TypeAdapter({render_type_expr(init_type)})
.validate_python
"""

assert (
render_init_method
init_type is None or render_init_method
), f"Unable to derive the init encoder from: {input_type}"

# Input renderer
@@ -922,10 +913,6 @@ def __init__(self, client: river.Client[Any]):
parse_output_method = "lambda x: None"

if procedure.type == "rpc":
control_flow_keyword = "return "
if output_type == "None":
control_flow_keyword = ""

current_chunks.extend(
[
reindent(
@@ -935,9 +922,9 @@ async def {name}(
self,
input: {render_type_expr(input_type)},
) -> {render_type_expr(output_type)}:
{control_flow_keyword}await self.client.send_rpc(
'{schema_name}',
'{name}',
return await self.client.send_rpc(
{repr(schema_name)},
{repr(name)},
input,
{reindent(" ", render_input_method)},
{reindent(" ", parse_output_method)},
@@ -958,8 +945,8 @@ async def {name}(
input: {render_type_expr(input_type)},
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_subscription(
'{schema_name}',
'{name}',
{repr(schema_name)},
{repr(name)},
input,
{reindent(" ", render_input_method)},
{reindent(" ", parse_output_method)},
@@ -970,10 +957,8 @@ async def {name}(
]
)
elif procedure.type == "upload":
control_flow_keyword = "return "
if output_type == "None":
control_flow_keyword = ""
if init_type:
assert render_init_method, "Expected an init renderer!"
current_chunks.extend(
[
reindent(
@@ -984,9 +969,9 @@ async def {name}(
init: {init_type},
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> {output_type}:
{control_flow_keyword}await self.client.send_upload(
'{schema_name}',
'{name}',
return await self.client.send_upload(
{repr(schema_name)},
{repr(name)},
init,
inputStream,
{reindent(" ", render_init_method)},
@@ -1008,9 +993,9 @@ async def {name}(
self,
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> {render_type_expr(output_or_error_type)}:
{control_flow_keyword}await self.client.send_upload(
'{schema_name}',
'{name}',
return await self.client.send_upload(
{repr(schema_name)},
{repr(name)},
None,
inputStream,
None,
@@ -1024,6 +1009,7 @@ async def {name}(
)
elif procedure.type == "stream":
if init_type:
assert render_init_method, "Expected an init renderer!"
current_chunks.extend(
[
reindent(
@@ -1035,8 +1021,8 @@ async def {name}(
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_stream(
'{schema_name}',
'{name}',
{repr(schema_name)},
{repr(name)},
init,
inputStream,
{reindent(" ", render_init_method)},
@@ -1059,8 +1045,8 @@ async def {name}(
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_stream(
'{schema_name}',
'{name}',
{repr(schema_name)},
{repr(name)},
None,
inputStream,
None,
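
A note on the recurring pattern in this diff: the generator used to build key lookups and string literals by wrapping values in hand-written single quotes (f"x['{name}']") and now routes them through repr(). The sketch below is not from the PR; the helper names and the sample key are made up. It only illustrates why repr() keeps the emitted source valid when a key contains a quote.

    def render_lookup_by_hand(name: str) -> str:
        # Old pattern: wrap the key in single quotes manually. The emitted
        # source breaks as soon as the key itself contains a single quote.
        return f"x['{name}']"

    def render_lookup_with_repr(name: str) -> str:
        # New pattern: repr() produces a properly escaped Python string literal.
        return f"x[{repr(name)}]"

    if __name__ == "__main__":
        key = "user's id"
        print(render_lookup_by_hand(key))    # x['user's id']   -> not valid Python
        print(render_lookup_with_repr(key))  # x["user's id"]   -> valid Python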
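
The diff also drops the standalone RiverNotType model and folds void/never handling into RiverConcreteType: a node whose type is "not" now encodes to None, and a node with no type at all encodes to Any. The following is a hedged, self-contained sketch of that dispatch using hypothetical names; it is not the library's actual model.

    from typing import Optional
    from pydantic import BaseModel

    class ConcreteNode(BaseModel):
        # Simplified stand-in for RiverConcreteType; only the "type" field.
        type: Optional[str] = None

    def python_type_for(node: ConcreteNode) -> str:
        if node.type is None:
            return "Any"   # type not specified
        if node.type == "not":
            return "None"  # void / never, formerly RiverNotType
        return node.type.title()

    print(python_type_for(ConcreteNode()))               # Any
    print(python_type_for(ConcreteNode(type="not")))     # None
    print(python_type_for(ConcreteNode(type="string")))  # String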
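
Finally, the generated rpc and upload methods now always emit "return await self.client.send_...(...)" instead of dropping the return keyword when the output type is None. A minimal sketch, assuming nothing about the river client itself: returning an awaited None-returning coroutine is valid Python and accepted by type checkers, so the branch was unnecessary.

    import asyncio

    async def send_rpc_stub() -> None:
        # Stand-in for a transport call whose result type is None.
        await asyncio.sleep(0)

    async def generated_method() -> None:
        # "return await ..." behaves like "await ...; return" here, so no
        # special case is needed for procedures that return nothing.
        return await send_rpc_stub()

    asyncio.run(generated_method())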