Skip to content

Commit edde572

Browse files
committed
updates
1 parent 64629dc commit edde572

File tree

1 file changed

+61
-9
lines changed

1 file changed

+61
-9
lines changed

src/replit_river/codegen/client.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pydantic import TypeAdapter
6666
6767
from replit_river.error_schema import RiverError
68+
RiverErrorTypeAdapter = TypeAdapter(RiverError)
6869
import replit_river as river
6970
7071
"""
@@ -761,6 +762,7 @@ def generate_individual_service(
761762
schema_name: str,
762763
schema: RiverService,
763764
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
765+
764766
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
765767
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
766768
class_name = ClassName(f"{schema_name.title()}Service")
@@ -798,27 +800,47 @@ def __init__(self, client: river.Client[Any]):
798800
module_names,
799801
permit_unknown_members=False,
800802
)
803+
input_type_name = extract_inner_type(input_type)
804+
input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter")
801805
serdes.append(
802806
(
803807
[extract_inner_type(input_type), *encoder_names],
804808
module_info,
805809
input_chunks,
806810
)
807811
)
812+
serdes.append(
813+
(
814+
[input_type_type_adapter_name],
815+
module_info,
816+
[f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"]
817+
)
818+
)
808819
output_type, module_info, output_chunks, encoder_names = encode_type(
809820
procedure.output,
810821
TypeName(f"{name.title()}Output"),
811822
"BaseModel",
812823
module_names,
813824
permit_unknown_members=True,
814825
)
826+
output_type_name = extract_inner_type(output_type)
815827
serdes.append(
816828
(
817-
[extract_inner_type(output_type), *encoder_names],
829+
[output_type_name, *encoder_names],
818830
module_info,
819831
output_chunks,
820832
)
821833
)
834+
output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter")
835+
# print('appending %r, %r' % (output_type_type_adapter_name, module_info))
836+
serdes.append(
837+
(
838+
[output_type_type_adapter_name],
839+
module_info,
840+
[f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"],
841+
)
842+
)
843+
output_module_info = module_info
822844
if procedure.errors:
823845
error_type, module_info, errors_chunks, encoder_names = encode_type(
824846
procedure.errors,
@@ -827,28 +849,48 @@ def __init__(self, client: river.Client[Any]):
827849
module_names,
828850
permit_unknown_members=True,
829851
)
852+
# print('error type module_info: %r' % module_info)
830853
if isinstance(error_type, NoneTypeExpr):
831-
error_type = TypeName("RiverError")
854+
error_type_name = TypeName("RiverError")
855+
error_type = error_type_name
832856
else:
857+
error_type_name = extract_inner_type(error_type)
833858
serdes.append(
834-
([extract_inner_type(error_type)], module_info, errors_chunks)
859+
([error_type_name], module_info, errors_chunks)
835860
)
861+
862+
836863
else:
837-
error_type = TypeName("RiverError")
838-
output_or_error_type = UnionTypeExpr([output_type, error_type])
864+
error_type_name = TypeName("RiverError")
865+
error_type = error_type_name
866+
867+
error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter")
868+
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
869+
# print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info))
870+
if len(module_info) == 0:
871+
module_info = output_module_info
872+
serdes.append(
873+
(
874+
[error_type_type_adapter_name],
875+
module_info,
876+
[f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"],
877+
)
878+
)
879+
output_or_error_type = UnionTypeExpr([output_type, error_type_name])
880+
839881

840882
# NB: These strings must be indented to at least the same level of
841883
# the function strings in the branches below, otherwise `dedent`
842884
# will pick our indentation level for normalization, which will
843885
# break the "def" indentation presuppositions.
844886
parse_output_method = f"""\
845-
lambda x: TypeAdapter({render_type_expr(output_type)})
887+
lambda x: {output_type_type_adapter_name.value}
846888
.validate_python(
847889
x # type: ignore[arg-type]
848890
)
849891
"""
850892
parse_error_method = f"""\
851-
lambda x: TypeAdapter({render_type_expr(error_type)})
893+
lambda x: {error_type_type_adapter_name.value}
852894
.validate_python(
853895
x # type: ignore[arg-type]
854896
)
@@ -871,8 +913,17 @@ def __init__(self, client: river.Client[Any]):
871913
else:
872914
render_init_method = f"encode_{render_literal_type(init_type)}"
873915
else:
916+
init_type_name = extract_inner_type(init_type)
917+
init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter")
918+
serdes.append(
919+
(
920+
[init_type_type_adapter_name],
921+
module_info,
922+
[f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"]
923+
)
924+
)
874925
render_init_method = f"""\
875-
lambda x: TypeAdapter({render_type_expr(init_type)})
926+
lambda x: {init_type_type_adapter_name.name})
876927
.validate_python
877928
"""
878929

@@ -898,8 +949,9 @@ def __init__(self, client: river.Client[Any]):
898949
else:
899950
render_input_method = f"encode_{render_literal_type(input_type)}"
900951
else:
952+
901953
render_input_method = f"""\
902-
lambda x: TypeAdapter({render_type_expr(input_type)})
954+
lambda x: {input_type_type_adapter_name.value}
903955
.dump_python(
904956
x, # type: ignore[arg-type]
905957
by_alias=True,

0 commit comments

Comments
 (0)