Skip to content

Commit 26ba67e

Browse files
Cache type adapters (#140)
Why === 1. Some py-spy tracing found a possible bottleneck with creating TypeAdapters. 2. Pydantic performance tips that says that [caching TypeAdapters is good](https://docs.pydantic.dev/latest/concepts/performance/#typeadapter-instantiated-once) Also replit/ai-infra#4672 What changed ============ Changed the river-python client to generate code that caches the TypeAdapters. ## Testing * tests pass * regenerating chat and pid2 clients and their tests still pass --------- Co-authored-by: Devon Stewart <devon.stewart@repl.it>
1 parent 64629dc commit 26ba67e

File tree

13 files changed

+144
-34
lines changed

13 files changed

+144
-34
lines changed

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.11
1+
3.12

.replit

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
run = "poetry run pytest tests"
22

3-
modules = ["python-3.11"]
3+
modules = ["python-3.12"]
44

55
[nix]
6-
channel = "stable-23_11"
6+
channel = "stable-24_11"

flake.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
1919
};
2020
packages = replitNixDeps ++ [
21-
pkgs.python311
21+
pkgs.python312
2222
pkgs.uv
2323
];
2424
};

src/replit_river/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ async def send_subscription(
129129
request_serializer: Callable[[RequestType], Any],
130130
response_deserializer: Callable[[Any], ResponseType],
131131
error_deserializer: Callable[[Any], ErrorType],
132-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
132+
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
133133
with _trace_procedure(
134134
"subscription", service_name, procedure_name
135135
) as span_handle:
@@ -157,7 +157,7 @@ async def send_stream(
157157
request_serializer: Callable[[RequestType], Any],
158158
response_deserializer: Callable[[Any], ResponseType],
159159
error_deserializer: Callable[[Any], ErrorType],
160-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
160+
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
161161
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
162162
session = await self._transport.get_or_create_session()
163163
async for msg in session.send_stream(

src/replit_river/codegen/client.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
6565
from pydantic import TypeAdapter
6666
67-
from replit_river.error_schema import RiverError
67+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
6868
import replit_river as river
6969
7070
"""
@@ -763,6 +763,27 @@ def generate_individual_service(
763763
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
764764
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
765765
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
766+
767+
def _type_adapter_definition(
768+
type_adapter_name: TypeName,
769+
_type: TypeExpression,
770+
module_info: list[ModuleName],
771+
) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]:
772+
rendered_type_expr = render_type_expr(_type)
773+
return (
774+
[type_adapter_name],
775+
module_info,
776+
[
777+
FileContents(
778+
dedent(f"""
779+
{render_type_expr(type_adapter_name)}: TypeAdapter[Any] = (
780+
TypeAdapter({rendered_type_expr})
781+
)
782+
""")
783+
)
784+
],
785+
)
786+
766787
class_name = ClassName(f"{schema_name.title()}Service")
767788
current_chunks: List[str] = [
768789
dedent(
@@ -798,27 +819,46 @@ def __init__(self, client: river.Client[Any]):
798819
module_names,
799820
permit_unknown_members=False,
800821
)
822+
input_type_name = extract_inner_type(input_type)
823+
input_type_type_adapter_name = TypeName(
824+
f"{render_literal_type(input_type_name)}TypeAdapter"
825+
)
801826
serdes.append(
802827
(
803828
[extract_inner_type(input_type), *encoder_names],
804829
module_info,
805830
input_chunks,
806831
)
807832
)
833+
serdes.append(
834+
_type_adapter_definition(
835+
input_type_type_adapter_name, input_type, module_info
836+
)
837+
)
808838
output_type, module_info, output_chunks, encoder_names = encode_type(
809839
procedure.output,
810840
TypeName(f"{name.title()}Output"),
811841
"BaseModel",
812842
module_names,
813843
permit_unknown_members=True,
814844
)
845+
output_type_name = extract_inner_type(output_type)
815846
serdes.append(
816847
(
817-
[extract_inner_type(output_type), *encoder_names],
848+
[output_type_name, *encoder_names],
818849
module_info,
819850
output_chunks,
820851
)
821852
)
853+
output_type_type_adapter_name = TypeName(
854+
f"{render_literal_type(output_type_name)}TypeAdapter"
855+
)
856+
serdes.append(
857+
_type_adapter_definition(
858+
output_type_type_adapter_name, output_type, module_info
859+
)
860+
)
861+
output_module_info = module_info
822862
if procedure.errors:
823863
error_type, module_info, errors_chunks, encoder_names = encode_type(
824864
procedure.errors,
@@ -828,27 +868,43 @@ def __init__(self, client: river.Client[Any]):
828868
permit_unknown_members=True,
829869
)
830870
if isinstance(error_type, NoneTypeExpr):
831-
error_type = TypeName("RiverError")
871+
error_type_name = TypeName("RiverError")
872+
error_type = error_type_name
832873
else:
833-
serdes.append(
834-
([extract_inner_type(error_type)], module_info, errors_chunks)
835-
)
874+
error_type_name = extract_inner_type(error_type)
875+
serdes.append(([error_type_name], module_info, errors_chunks))
876+
836877
else:
837-
error_type = TypeName("RiverError")
838-
output_or_error_type = UnionTypeExpr([output_type, error_type])
878+
error_type_name = TypeName("RiverError")
879+
error_type = error_type_name
880+
881+
error_type_type_adapter_name = TypeName(
882+
f"{render_literal_type(error_type_name)}TypeAdapter"
883+
)
884+
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
885+
if len(module_info) == 0:
886+
module_info = output_module_info
887+
serdes.append(
888+
_type_adapter_definition(
889+
error_type_type_adapter_name, error_type, module_info
890+
)
891+
)
892+
output_or_error_type = UnionTypeExpr([output_type, error_type_name])
839893

840894
# NB: These strings must be indented to at least the same level of
841895
# the function strings in the branches below, otherwise `dedent`
842896
# will pick our indentation level for normalization, which will
843897
# break the "def" indentation presuppositions.
898+
output_type_adapter = render_literal_type(output_type_type_adapter_name)
844899
parse_output_method = f"""\
845-
lambda x: TypeAdapter({render_type_expr(output_type)})
900+
lambda x: {output_type_adapter}
846901
.validate_python(
847902
x # type: ignore[arg-type]
848903
)
849904
"""
905+
error_type_adapter = render_literal_type(error_type_type_adapter_name)
850906
parse_error_method = f"""\
851-
lambda x: TypeAdapter({render_type_expr(error_type)})
907+
lambda x: {error_type_adapter}
852908
.validate_python(
853909
x # type: ignore[arg-type]
854910
)
@@ -871,9 +927,18 @@ def __init__(self, client: river.Client[Any]):
871927
else:
872928
render_init_method = f"encode_{render_literal_type(init_type)}"
873929
else:
930+
init_type_name = extract_inner_type(init_type)
931+
init_type_type_adapter_name = TypeName(
932+
f"{init_type_name.value}TypeAdapter"
933+
)
934+
serdes.append(
935+
_type_adapter_definition(
936+
init_type_type_adapter_name, init_type, module_info
937+
)
938+
)
874939
render_init_method = f"""\
875-
lambda x: TypeAdapter({render_type_expr(init_type)})
876-
.validate_python
940+
lambda x: {render_type_expr(init_type_type_adapter_name)}
941+
.validate_python
877942
"""
878943

879944
assert init_type is None or render_init_method, (
@@ -889,17 +954,17 @@ def __init__(self, client: river.Client[Any]):
889954
procedure.input, RiverConcreteType
890955
) and procedure.input.type in ["array"]:
891956
match input_type:
892-
case ListTypeExpr(input_type_name):
957+
case ListTypeExpr(list_type):
893958
render_input_method = f"""\
894959
lambda xs: [
895-
encode_{render_literal_type(input_type_name)}(x) for x in xs
960+
encode_{render_literal_type(list_type)}(x) for x in xs
896961
]
897962
"""
898963
else:
899964
render_input_method = f"encode_{render_literal_type(input_type)}"
900965
else:
901966
render_input_method = f"""\
902-
lambda x: TypeAdapter({render_type_expr(input_type)})
967+
lambda x: {render_type_expr(input_type_type_adapter_name)}
903968
.dump_python(
904969
x, # type: ignore[arg-type]
905970
by_alias=True,

src/replit_river/error_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, List, Optional
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, TypeAdapter
44

55
ERROR_CODE_STREAM_CLOSED = "stream_closed"
66
ERROR_HANDSHAKE = "handshake_failed"
@@ -25,6 +25,9 @@ class RiverError(BaseModel):
2525
message: str
2626

2727

28+
RiverErrorTypeAdapter = TypeAdapter(RiverError)
29+
30+
2831
class RiverException(Exception):
2932
"""Exception raised by the River server."""
3033

tests/codegen/rpc/generated/test_service/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55

66
from pydantic import TypeAdapter
77

8-
from replit_river.error_schema import RiverError
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
99
import replit_river as river
1010

1111

12-
from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput
12+
from .rpc_method import (
13+
Rpc_MethodInput,
14+
Rpc_MethodInputTypeAdapter,
15+
Rpc_MethodOutput,
16+
Rpc_MethodOutputTypeAdapter,
17+
encode_Rpc_MethodInput,
18+
)
1319

1420

1521
class Test_ServiceService:
@@ -26,10 +32,10 @@ async def rpc_method(
2632
"rpc_method",
2733
input,
2834
encode_Rpc_MethodInput,
29-
lambda x: TypeAdapter(Rpc_MethodOutput).validate_python(
35+
lambda x: Rpc_MethodOutputTypeAdapter.validate_python(
3036
x # type: ignore[arg-type]
3137
),
32-
lambda x: TypeAdapter(RiverError).validate_python(
38+
lambda x: RiverErrorTypeAdapter.validate_python(
3339
x # type: ignore[arg-type]
3440
),
3541
timeout,

tests/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict):
3939
data: str
4040

4141

42+
Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput)
43+
44+
4245
class Rpc_MethodOutput(BaseModel):
4346
data: str
47+
48+
49+
Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput)

tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,26 @@
55

66
from pydantic import TypeAdapter
77

8-
from replit_river.error_schema import RiverError
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
99
import replit_river as river
1010

1111

1212
from .needsEnum import (
1313
NeedsenumErrors,
14+
NeedsenumErrorsTypeAdapter,
1415
NeedsenumInput,
16+
NeedsenumInputTypeAdapter,
1517
NeedsenumOutput,
18+
NeedsenumOutputTypeAdapter,
1619
encode_NeedsenumInput,
1720
)
1821
from .needsEnumObject import (
1922
NeedsenumobjectErrors,
23+
NeedsenumobjectErrorsTypeAdapter,
2024
NeedsenumobjectInput,
25+
NeedsenumobjectInputTypeAdapter,
2126
NeedsenumobjectOutput,
27+
NeedsenumobjectOutputTypeAdapter,
2228
encode_NeedsenumobjectInput,
2329
)
2430

@@ -37,10 +43,10 @@ async def needsEnum(
3743
"needsEnum",
3844
input,
3945
lambda x: x,
40-
lambda x: TypeAdapter(NeedsenumOutput).validate_python(
46+
lambda x: NeedsenumOutputTypeAdapter.validate_python(
4147
x # type: ignore[arg-type]
4248
),
43-
lambda x: TypeAdapter(NeedsenumErrors).validate_python(
49+
lambda x: NeedsenumErrorsTypeAdapter.validate_python(
4450
x # type: ignore[arg-type]
4551
),
4652
timeout,
@@ -56,10 +62,10 @@ async def needsEnumObject(
5662
"needsEnumObject",
5763
input,
5864
encode_NeedsenumobjectInput,
59-
lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python(
65+
lambda x: NeedsenumobjectOutputTypeAdapter.validate_python(
6066
x # type: ignore[arg-type]
6167
),
62-
lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python(
68+
lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python(
6369
x # type: ignore[arg-type]
6470
),
6571
timeout,

tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,19 @@
2626

2727
NeedsenumInput = Literal["in_first"] | Literal["in_second"]
2828
encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x
29+
30+
NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput)
31+
2932
NeedsenumOutput = Annotated[
3033
Literal["out_first"] | Literal["out_second"] | RiverUnknownValue,
3134
WrapValidator(translate_unknown_value),
3235
]
36+
37+
NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput)
38+
3339
NeedsenumErrors = Annotated[
3440
Literal["err_first"] | Literal["err_second"] | RiverUnknownValue,
3541
WrapValidator(translate_unknown_value),
3642
]
43+
44+
NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors)

0 commit comments

Comments
 (0)