Skip to content
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.11
3.12
4 changes: 2 additions & 2 deletions .replit
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
run = "poetry run pytest tests"

modules = ["python-3.11"]
modules = ["python-3.12"]

[nix]
channel = "stable-23_11"
channel = "stable-24_11"
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
};
packages = replitNixDeps ++ [
pkgs.python311
pkgs.python312
pkgs.uv
];
};
Expand Down
4 changes: 2 additions & 2 deletions src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
with _trace_procedure(
"subscription", service_name, procedure_name
) as span_handle:
Expand Down Expand Up @@ -157,7 +157,7 @@ async def send_stream(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
Expand Down
95 changes: 80 additions & 15 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river

"""
Expand Down Expand Up @@ -763,6 +763,27 @@ def generate_individual_service(
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []

def _type_adapter_definition(
type_adapter_name: TypeName,
_type: TypeExpression,
module_info: list[ModuleName],
) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]:
rendered_type_expr = render_type_expr(_type)
return (
[type_adapter_name],
module_info,
[
FileContents(
dedent(f"""
{render_type_expr(type_adapter_name)}: TypeAdapter[Any] = (
TypeAdapter({rendered_type_expr})
)
""")
)
],
)

class_name = ClassName(f"{schema_name.title()}Service")
current_chunks: List[str] = [
dedent(
Expand Down Expand Up @@ -798,27 +819,46 @@ def __init__(self, client: river.Client[Any]):
module_names,
permit_unknown_members=False,
)
input_type_name = extract_inner_type(input_type)
input_type_type_adapter_name = TypeName(
f"{render_literal_type(input_type_name)}TypeAdapter"
)
serdes.append(
(
[extract_inner_type(input_type), *encoder_names],
module_info,
input_chunks,
)
)
serdes.append(
_type_adapter_definition(
input_type_type_adapter_name, input_type, module_info
)
)
output_type, module_info, output_chunks, encoder_names = encode_type(
procedure.output,
TypeName(f"{name.title()}Output"),
"BaseModel",
module_names,
permit_unknown_members=True,
)
output_type_name = extract_inner_type(output_type)
serdes.append(
(
[extract_inner_type(output_type), *encoder_names],
[output_type_name, *encoder_names],
module_info,
output_chunks,
)
)
output_type_type_adapter_name = TypeName(
f"{render_literal_type(output_type_name)}TypeAdapter"
)
serdes.append(
_type_adapter_definition(
output_type_type_adapter_name, output_type, module_info
)
)
output_module_info = module_info
if procedure.errors:
error_type, module_info, errors_chunks, encoder_names = encode_type(
procedure.errors,
Expand All @@ -828,27 +868,43 @@ def __init__(self, client: river.Client[Any]):
permit_unknown_members=True,
)
if isinstance(error_type, NoneTypeExpr):
error_type = TypeName("RiverError")
error_type_name = TypeName("RiverError")
error_type = error_type_name
else:
serdes.append(
([extract_inner_type(error_type)], module_info, errors_chunks)
)
error_type_name = extract_inner_type(error_type)
serdes.append(([error_type_name], module_info, errors_chunks))

else:
error_type = TypeName("RiverError")
output_or_error_type = UnionTypeExpr([output_type, error_type])
error_type_name = TypeName("RiverError")
error_type = error_type_name

error_type_type_adapter_name = TypeName(
f"{render_literal_type(error_type_name)}TypeAdapter"
)
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
if len(module_info) == 0:
module_info = output_module_info
serdes.append(
_type_adapter_definition(
error_type_type_adapter_name, error_type, module_info
)
)
output_or_error_type = UnionTypeExpr([output_type, error_type_name])

# NB: These strings must be indented to at least the same level of
# the function strings in the branches below, otherwise `dedent`
# will pick our indentation level for normalization, which will
# break the "def" indentation presuppositions.
output_type_adapter = render_literal_type(output_type_type_adapter_name)
parse_output_method = f"""\
lambda x: TypeAdapter({render_type_expr(output_type)})
lambda x: {output_type_adapter}
.validate_python(
x # type: ignore[arg-type]
)
"""
error_type_adapter = render_literal_type(error_type_type_adapter_name)
parse_error_method = f"""\
lambda x: TypeAdapter({render_type_expr(error_type)})
lambda x: {error_type_adapter}
.validate_python(
x # type: ignore[arg-type]
)
Expand All @@ -871,9 +927,18 @@ def __init__(self, client: river.Client[Any]):
else:
render_init_method = f"encode_{render_literal_type(init_type)}"
else:
init_type_name = extract_inner_type(init_type)
init_type_type_adapter_name = TypeName(
f"{init_type_name.value}TypeAdapter"
)
serdes.append(
_type_adapter_definition(
init_type_type_adapter_name, init_type, module_info
)
)
render_init_method = f"""\
lambda x: TypeAdapter({render_type_expr(init_type)})
.validate_python
lambda x: {render_type_expr(init_type_type_adapter_name)}
.validate_python
"""

assert init_type is None or render_init_method, (
Expand All @@ -889,17 +954,17 @@ def __init__(self, client: river.Client[Any]):
procedure.input, RiverConcreteType
) and procedure.input.type in ["array"]:
match input_type:
case ListTypeExpr(input_type_name):
case ListTypeExpr(list_type):
render_input_method = f"""\
lambda xs: [
encode_{render_literal_type(input_type_name)}(x) for x in xs
encode_{render_literal_type(list_type)}(x) for x in xs
]
"""
else:
render_input_method = f"encode_{render_literal_type(input_type)}"
else:
render_input_method = f"""\
lambda x: TypeAdapter({render_type_expr(input_type)})
lambda x: {render_type_expr(input_type_type_adapter_name)}
.dump_python(
x, # type: ignore[arg-type]
by_alias=True,
Expand Down
5 changes: 4 additions & 1 deletion src/replit_river/error_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

ERROR_CODE_STREAM_CLOSED = "stream_closed"
ERROR_HANDSHAKE = "handshake_failed"
Expand All @@ -25,6 +25,9 @@ class RiverError(BaseModel):
message: str


RiverErrorTypeAdapter = TypeAdapter(RiverError)


class RiverException(Exception):
"""Exception raised by the River server."""

Expand Down
14 changes: 10 additions & 4 deletions tests/codegen/rpc/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput
from .rpc_method import (
Rpc_MethodInput,
Rpc_MethodInputTypeAdapter,
Rpc_MethodOutput,
Rpc_MethodOutputTypeAdapter,
encode_Rpc_MethodInput,
)


class Test_ServiceService:
Expand All @@ -26,10 +32,10 @@ async def rpc_method(
"rpc_method",
input,
encode_Rpc_MethodInput,
lambda x: TypeAdapter(Rpc_MethodOutput).validate_python(
lambda x: Rpc_MethodOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand Down
6 changes: 6 additions & 0 deletions tests/codegen/rpc/generated/test_service/rpc_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict):
data: str


Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput)


class Rpc_MethodOutput(BaseModel):
data: str


Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput)
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


from .needsEnum import (
NeedsenumErrors,
NeedsenumErrorsTypeAdapter,
NeedsenumInput,
NeedsenumInputTypeAdapter,
NeedsenumOutput,
NeedsenumOutputTypeAdapter,
encode_NeedsenumInput,
)
from .needsEnumObject import (
NeedsenumobjectErrors,
NeedsenumobjectErrorsTypeAdapter,
NeedsenumobjectInput,
NeedsenumobjectInputTypeAdapter,
NeedsenumobjectOutput,
NeedsenumobjectOutputTypeAdapter,
encode_NeedsenumobjectInput,
)

Expand All @@ -37,10 +43,10 @@ async def needsEnum(
"needsEnum",
input,
lambda x: x,
lambda x: TypeAdapter(NeedsenumOutput).validate_python(
lambda x: NeedsenumOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(NeedsenumErrors).validate_python(
lambda x: NeedsenumErrorsTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand All @@ -56,10 +62,10 @@ async def needsEnumObject(
"needsEnumObject",
input,
encode_NeedsenumobjectInput,
lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python(
lambda x: NeedsenumobjectOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python(
lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@

NeedsenumInput = Literal["in_first"] | Literal["in_second"]
encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x

NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput)

NeedsenumOutput = Annotated[
Literal["out_first"] | Literal["out_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
]

NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput)

NeedsenumErrors = Annotated[
Literal["err_first"] | Literal["err_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
]

NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors)
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict):
else encode_NeedsenumobjectInputOneOf_in_second(x)
)

NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput)


class NeedsenumobjectOutputFooOneOf_out_first(BaseModel):
kind: Literal["out_first"] = Field(
Expand Down Expand Up @@ -103,6 +105,9 @@ class NeedsenumobjectOutput(BaseModel):
foo: Optional[NeedsenumobjectOutputFoo] = None


NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput)


class NeedsenumobjectErrorsFooAnyOf_0(RiverError):
beep: Optional[Literal["err_first"]] = None

Expand All @@ -121,3 +126,6 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError):

class NeedsenumobjectErrors(RiverError):
foo: Optional[NeedsenumobjectErrorsFoo] = None


NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors)
Loading
Loading