Skip to content

Commit 98d29e8

Browse files
committed
wip
1 parent 64629dc commit 98d29e8

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

src/replit_river/codegen/client.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def generate_individual_service(
761761
schema_name: str,
762762
schema: RiverService,
763763
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
764+
764765
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
765766
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
766767
class_name = ClassName(f"{schema_name.title()}Service")
@@ -812,13 +813,22 @@ def __init__(self, client: river.Client[Any]):
812813
module_names,
813814
permit_unknown_members=True,
814815
)
816+
output_type_name = extract_inner_type(output_type)
815817
serdes.append(
816818
(
817-
[extract_inner_type(output_type), *encoder_names],
819+
[output_type_name, *encoder_names],
818820
module_info,
819821
output_chunks,
820822
)
821823
)
824+
output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter")
825+
serdes.append(
826+
(
827+
[output_type_type_adapter_name],
828+
module_info,
829+
[f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"],
830+
)
831+
)
822832
if procedure.errors:
823833
error_type, module_info, errors_chunks, encoder_names = encode_type(
824834
procedure.errors,
@@ -828,27 +838,41 @@ def __init__(self, client: river.Client[Any]):
828838
permit_unknown_members=True,
829839
)
830840
if isinstance(error_type, NoneTypeExpr):
831-
error_type = TypeName("RiverError")
841+
error_type_name = TypeName("RiverError")
842+
error_type = error_type_name
832843
else:
844+
error_type_name = extract_inner_type(error_type)
833845
serdes.append(
834-
([extract_inner_type(error_type)], module_info, errors_chunks)
846+
([error_type_name], module_info, errors_chunks)
847+
)
848+
849+
error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter")
850+
serdes.append(
851+
(
852+
[error_type_type_adapter_name],
853+
module_info,
854+
[f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"],
835855
)
856+
)
836857
else:
837-
error_type = TypeName("RiverError")
838-
output_or_error_type = UnionTypeExpr([output_type, error_type])
858+
error_type_name = TypeName("RiverError")
859+
860+
861+
output_or_error_type = UnionTypeExpr([output_type, error_type_name])
862+
839863

840864
# NB: These strings must be indented to at least the same level of
841865
# the function strings in the branches below, otherwise `dedent`
842866
# will pick our indentation level for normalization, which will
843867
# break the "def" indentation presuppositions.
844868
parse_output_method = f"""\
845-
lambda x: TypeAdapter({render_type_expr(output_type)})
869+
lambda x: {output_type_type_adapter_name.value}
846870
.validate_python(
847871
x # type: ignore[arg-type]
848872
)
849873
"""
850874
parse_error_method = f"""\
851-
lambda x: TypeAdapter({render_type_expr(error_type)})
875+
lambda x: {error_type_type_adapter_name.value}
852876
.validate_python(
853877
x # type: ignore[arg-type]
854878
)
@@ -871,8 +895,17 @@ def __init__(self, client: river.Client[Any]):
871895
else:
872896
render_init_method = f"encode_{render_literal_type(init_type)}"
873897
else:
898+
init_type_name = extract_inner_type(init_type)
899+
init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter")
900+
serdes.append(
901+
(
902+
[init_type_type_adapter_name],
903+
module_info,
904+
[f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"]
905+
)
906+
)
874907
render_init_method = f"""\
875-
lambda x: TypeAdapter({render_type_expr(init_type)})
908+
lambda x: {init_type_type_adapter_name.name})
876909
.validate_python
877910
"""
878911

@@ -898,8 +931,17 @@ def __init__(self, client: river.Client[Any]):
898931
else:
899932
render_input_method = f"encode_{render_literal_type(input_type)}"
900933
else:
934+
input_type_name = extract_inner_type(input_type)
935+
input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter")
936+
serdes.append(
937+
(
938+
[input_type_type_adapter_name],
939+
module_info,
940+
[f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"]
941+
)
942+
)
901943
render_input_method = f"""\
902-
lambda x: TypeAdapter({render_type_expr(input_type)})
944+
lambda x: {input_type_type_adapter_name.value}
903945
.dump_python(
904946
x, # type: ignore[arg-type]
905947
by_alias=True,

0 commit comments

Comments
 (0)