6464
6565from pydantic import TypeAdapter
6666
67- from replit_river.error_schema import RiverError
68- RiverErrorTypeAdapter = TypeAdapter(RiverError)
67+ from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
6968import replit_river as river
7069
7170"""
@@ -763,26 +762,25 @@ def generate_individual_service(
763762 schema : RiverService ,
764763 input_base_class : Literal ["TypedDict" ] | Literal ["BaseModel" ],
765764) -> Tuple [ModuleName , ClassName , dict [RenderedPath , FileContents ]]:
765+ serdes : list [Tuple [list [TypeName ], list [ModuleName ], list [FileContents ]]] = []
766+
766767 def append_type_adapter_definition (
767768 type_adapter_name : TypeName ,
768769 _type : TypeExpression ,
769770 module_info : list [ModuleName ],
770771 ) -> None :
771772 rendered_type_expr = render_type_expr (_type )
773+ var_name = render_type_expr (type_adapter_name )
774+ var_type = f"TypeAdapter[{ rendered_type_expr } ]"
775+ var_value = f"TypeAdapter({ rendered_type_expr } )"
772776 serdes .append (
773777 (
774778 [type_adapter_name ],
775779 module_info ,
776- [
777- FileContents (
778- f"{ type_adapter_name .value } : TypeAdapter[{ rendered_type_expr } ] = "
779- f"TypeAdapter({ rendered_type_expr } )"
780- )
781- ],
780+ [FileContents (f"{ var_name } : { var_type } = { var_value } " )],
782781 )
783782 )
784783
785- serdes : list [Tuple [list [TypeName ], list [ModuleName ], list [FileContents ]]] = []
786784 class_name = ClassName (f"{ schema_name .title ()} Service" )
787785 current_chunks : List [str ] = [
788786 dedent (
@@ -819,7 +817,9 @@ def __init__(self, client: river.Client[Any]):
819817 permit_unknown_members = False ,
820818 )
821819 input_type_name = extract_inner_type (input_type )
822- input_type_type_adapter_name = TypeName (f"{ input_type_name .value } TypeAdapter" )
820+ input_type_type_adapter_name = TypeName (
821+ f"{ render_literal_type (input_type_name )} TypeAdapter"
822+ )
823823 serdes .append (
824824 (
825825 [extract_inner_type (input_type ), * encoder_names ],
@@ -845,7 +845,9 @@ def __init__(self, client: river.Client[Any]):
845845 output_chunks ,
846846 )
847847 )
848- output_type_type_adapter_name = TypeName (f"{ output_type_name .value } TypeAdapter" )
848+ output_type_type_adapter_name = TypeName (
849+ f"{ render_literal_type (output_type_name )} TypeAdapter"
850+ )
849851 append_type_adapter_definition (
850852 output_type_type_adapter_name , output_type , module_info
851853 )
@@ -869,7 +871,9 @@ def __init__(self, client: river.Client[Any]):
869871 error_type_name = TypeName ("RiverError" )
870872 error_type = error_type_name
871873
872- error_type_type_adapter_name = TypeName (f"{ error_type_name .value } TypeAdapter" )
874+ error_type_type_adapter_name = TypeName (
875+ f"{ render_literal_type (error_type_name )} TypeAdapter"
876+ )
873877 if error_type_type_adapter_name .value != "RiverErrorTypeAdapter" :
874878 if len (module_info ) == 0 :
875879 module_info = output_module_info
@@ -882,14 +886,16 @@ def __init__(self, client: river.Client[Any]):
882886 # the function strings in the branches below, otherwise `dedent`
883887 # will pick our indentation level for normalization, which will
884888 # break the "def" indentation presuppositions.
889+ ottd_name = render_literal_type (output_type_type_adapter_name )
885890 parse_output_method = f"""\
886- lambda x: { output_type_type_adapter_name . value }
891+ lambda x: { ottd_name }
887892 .validate_python(
888893 x # type: ignore[arg-type]
889894 )
890895 """
896+ ettd_name = render_literal_type (error_type_type_adapter_name )
891897 parse_error_method = f"""\
892- lambda x: { error_type_type_adapter_name . value }
898+ lambda x: { ettd_name }
893899 .validate_python(
894900 x # type: ignore[arg-type]
895901 )
@@ -920,8 +926,8 @@ def __init__(self, client: river.Client[Any]):
920926 init_type_type_adapter_name , init_type , module_info
921927 )
922928 render_init_method = f"""\
923- lambda x: { init_type_type_adapter_name . value } )
924- .validate_python
929+ lambda x: { render_type_expr ( init_type_type_adapter_name ) }
930+ .validate_python
925931 """
926932
927933 assert init_type is None or render_init_method , (
@@ -947,7 +953,7 @@ def __init__(self, client: river.Client[Any]):
947953 render_input_method = f"encode_{ render_literal_type (input_type )} "
948954 else :
949955 render_input_method = f"""\
950- lambda x: { input_type_type_adapter_name . value }
956+ lambda x: { render_type_expr ( input_type_type_adapter_name ) }
951957 .dump_python(
952958 x, # type: ignore[arg-type]
953959 by_alias=True,
0 commit comments