Skip to content

Commit 8dcd439

Browse files
committed
regened the code
1 parent c77679e commit 8dcd439

File tree

6 files changed

+32
-30
lines changed

6 files changed

+32
-30
lines changed

src/replit_river/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
)
2222

2323
from .rpc import (
24-
ErrorType,
2524
InitType,
2625
RequestType,
2726
ResponseType,
@@ -129,7 +128,7 @@ async def send_subscription(
129128
request_serializer: Callable[[RequestType], Any],
130129
response_deserializer: Callable[[Any], ResponseType],
131130
error_deserializer: Callable[[Any], Any],
132-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
131+
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
133132
with _trace_procedure(
134133
"subscription", service_name, procedure_name
135134
) as span_handle:
@@ -157,7 +156,7 @@ async def send_stream(
157156
request_serializer: Callable[[RequestType], Any],
158157
response_deserializer: Callable[[Any], ResponseType],
159158
error_deserializer: Callable[[Any], Any],
160-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
159+
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
161160
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
162161
session = await self._transport.get_or_create_session()
163162
async for msg in session.send_stream(

src/replit_river/codegen/client.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@
6464
6565
from pydantic import TypeAdapter
6666
67-
from replit_river.error_schema import RiverError
68-
RiverErrorTypeAdapter = TypeAdapter(RiverError)
67+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
6968
import replit_river as river
7069
7170
"""
@@ -763,26 +762,25 @@ def generate_individual_service(
763762
schema: RiverService,
764763
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
765764
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
765+
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
766+
766767
def append_type_adapter_definition(
767768
type_adapter_name: TypeName,
768769
_type: TypeExpression,
769770
module_info: list[ModuleName],
770771
) -> None:
771772
rendered_type_expr = render_type_expr(_type)
773+
var_name = render_type_expr(type_adapter_name)
774+
var_type = f"TypeAdapter[{rendered_type_expr}]"
775+
var_value = f"TypeAdapter({rendered_type_expr})"
772776
serdes.append(
773777
(
774778
[type_adapter_name],
775779
module_info,
776-
[
777-
FileContents(
778-
f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = "
779-
f"TypeAdapter({rendered_type_expr})"
780-
)
781-
],
780+
[FileContents(f"{var_name}: {var_type} = {var_value}")],
782781
)
783782
)
784783

785-
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
786784
class_name = ClassName(f"{schema_name.title()}Service")
787785
current_chunks: List[str] = [
788786
dedent(
@@ -819,7 +817,9 @@ def __init__(self, client: river.Client[Any]):
819817
permit_unknown_members=False,
820818
)
821819
input_type_name = extract_inner_type(input_type)
822-
input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter")
820+
input_type_type_adapter_name = TypeName(
821+
f"{render_literal_type(input_type_name)}TypeAdapter"
822+
)
823823
serdes.append(
824824
(
825825
[extract_inner_type(input_type), *encoder_names],
@@ -845,7 +845,9 @@ def __init__(self, client: river.Client[Any]):
845845
output_chunks,
846846
)
847847
)
848-
output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter")
848+
output_type_type_adapter_name = TypeName(
849+
f"{render_literal_type(output_type_name)}TypeAdapter"
850+
)
849851
append_type_adapter_definition(
850852
output_type_type_adapter_name, output_type, module_info
851853
)
@@ -869,7 +871,9 @@ def __init__(self, client: river.Client[Any]):
869871
error_type_name = TypeName("RiverError")
870872
error_type = error_type_name
871873

872-
error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter")
874+
error_type_type_adapter_name = TypeName(
875+
f"{render_literal_type(error_type_name)}TypeAdapter"
876+
)
873877
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
874878
if len(module_info) == 0:
875879
module_info = output_module_info
@@ -882,14 +886,16 @@ def __init__(self, client: river.Client[Any]):
882886
# the function strings in the branches below, otherwise `dedent`
883887
# will pick our indentation level for normalization, which will
884888
# break the "def" indentation presuppositions.
889+
ottd_name = render_literal_type(output_type_type_adapter_name)
885890
parse_output_method = f"""\
886-
lambda x: {output_type_type_adapter_name.value}
891+
lambda x: {ottd_name}
887892
.validate_python(
888893
x # type: ignore[arg-type]
889894
)
890895
"""
896+
ettd_name = render_literal_type(error_type_type_adapter_name)
891897
parse_error_method = f"""\
892-
lambda x: {error_type_type_adapter_name.value}
898+
lambda x: {ettd_name}
893899
.validate_python(
894900
x # type: ignore[arg-type]
895901
)
@@ -920,8 +926,8 @@ def __init__(self, client: river.Client[Any]):
920926
init_type_type_adapter_name, init_type, module_info
921927
)
922928
render_init_method = f"""\
923-
lambda x: {init_type_type_adapter_name.value})
924-
.validate_python
929+
lambda x: {render_type_expr(init_type_type_adapter_name)}
930+
.validate_python
925931
"""
926932

927933
assert init_type is None or render_init_method, (
@@ -947,7 +953,7 @@ def __init__(self, client: river.Client[Any]):
947953
render_input_method = f"encode_{render_literal_type(input_type)}"
948954
else:
949955
render_input_method = f"""\
950-
lambda x: {input_type_type_adapter_name.value}
956+
lambda x: {render_type_expr(input_type_type_adapter_name)}
951957
.dump_python(
952958
x, # type: ignore[arg-type]
953959
by_alias=True,

src/replit_river/error_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, List, Optional
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, TypeAdapter
44

55
ERROR_CODE_STREAM_CLOSED = "stream_closed"
66
ERROR_HANDSHAKE = "handshake_failed"
@@ -25,6 +25,9 @@ class RiverError(BaseModel):
2525
message: str
2626

2727

28+
RiverErrorTypeAdapter = TypeAdapter(RiverError)
29+
30+
2831
class RiverException(Exception):
2932
"""Exception raised by the River server."""
3033

tests/codegen/rpc/generated/test_service/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
from pydantic import TypeAdapter
77

8-
from replit_river.error_schema import RiverError
9-
10-
RiverErrorTypeAdapter = TypeAdapter(RiverError)
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
119
import replit_river as river
1210

1311

tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
from pydantic import TypeAdapter
77

8-
from replit_river.error_schema import RiverError
9-
10-
RiverErrorTypeAdapter = TypeAdapter(RiverError)
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
119
import replit_river as river
1210

1311

tests/codegen/stream/generated/test_service/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
from pydantic import TypeAdapter
77

8-
from replit_river.error_schema import RiverError
9-
10-
RiverErrorTypeAdapter = TypeAdapter(RiverError)
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
119
import replit_river as river
1210

1311

0 commit comments

Comments
 (0)