6565from pydantic import TypeAdapter
6666
6767from replit_river.error_schema import RiverError
68+ RiverErrorTypeAdapter = TypeAdapter(RiverError)
6869import replit_river as river
6970
7071"""
@@ -761,6 +762,7 @@ def generate_individual_service(
761762 schema_name : str ,
762763 schema : RiverService ,
763764 input_base_class : Literal ["TypedDict" ] | Literal ["BaseModel" ],
765+
764766) -> Tuple [ModuleName , ClassName , dict [RenderedPath , FileContents ]]:
765767 serdes : list [Tuple [list [TypeName ], list [ModuleName ], list [FileContents ]]] = []
766768 class_name = ClassName (f"{ schema_name .title ()} Service" )
@@ -798,27 +800,47 @@ def __init__(self, client: river.Client[Any]):
798800 module_names ,
799801 permit_unknown_members = False ,
800802 )
803+ input_type_name = extract_inner_type (input_type )
804+ input_type_type_adapter_name = TypeName (f"{ input_type_name .value } TypeAdapter" )
801805 serdes .append (
802806 (
803807 [extract_inner_type (input_type ), * encoder_names ],
804808 module_info ,
805809 input_chunks ,
806810 )
807811 )
812+ serdes .append (
813+ (
814+ [input_type_type_adapter_name ],
815+ module_info ,
816+ [f"{ input_type_type_adapter_name .value } = TypeAdapter({ render_type_expr (input_type )} ) # type: ignore" ]
817+ )
818+ )
808819 output_type , module_info , output_chunks , encoder_names = encode_type (
809820 procedure .output ,
810821 TypeName (f"{ name .title ()} Output" ),
811822 "BaseModel" ,
812823 module_names ,
813824 permit_unknown_members = True ,
814825 )
826+ output_type_name = extract_inner_type (output_type )
815827 serdes .append (
816828 (
817- [extract_inner_type ( output_type ) , * encoder_names ],
829+ [output_type_name , * encoder_names ],
818830 module_info ,
819831 output_chunks ,
820832 )
821833 )
834+ output_type_type_adapter_name = TypeName (f"{ output_type_name .value } TypeAdapter" )
835+ # print('appending %r, %r' % (output_type_type_adapter_name, module_info))
836+ serdes .append (
837+ (
838+ [output_type_type_adapter_name ],
839+ module_info ,
840+ [f"{ output_type_type_adapter_name .value } = TypeAdapter({ render_type_expr (output_type )} ) # type: ignore" ],
841+ )
842+ )
843+ output_module_info = module_info
822844 if procedure .errors :
823845 error_type , module_info , errors_chunks , encoder_names = encode_type (
824846 procedure .errors ,
@@ -827,28 +849,48 @@ def __init__(self, client: river.Client[Any]):
827849 module_names ,
828850 permit_unknown_members = True ,
829851 )
852+ # print('error type module_info: %r' % module_info)
830853 if isinstance (error_type , NoneTypeExpr ):
831- error_type = TypeName ("RiverError" )
854+ error_type_name = TypeName ("RiverError" )
855+ error_type = error_type_name
832856 else :
857+ error_type_name = extract_inner_type (error_type )
833858 serdes .append (
834- ([extract_inner_type ( error_type ) ], module_info , errors_chunks )
859+ ([error_type_name ], module_info , errors_chunks )
835860 )
861+
862+
836863 else :
837- error_type = TypeName ("RiverError" )
838- output_or_error_type = UnionTypeExpr ([output_type , error_type ])
864+ error_type_name = TypeName ("RiverError" )
865+ error_type = error_type_name
866+
867+ error_type_type_adapter_name = TypeName (f"{ error_type .value } TypeAdapter" )
868+ if error_type_type_adapter_name .value != "RiverErrorTypeAdapter" :
869+ # print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info))
870+ if len (module_info ) == 0 :
871+ module_info = output_module_info
872+ serdes .append (
873+ (
874+ [error_type_type_adapter_name ],
875+ module_info ,
876+ [f"{ error_type_type_adapter_name .value } = TypeAdapter({ render_type_expr (error_type )} ) # type: ignore" ],
877+ )
878+ )
879+ output_or_error_type = UnionTypeExpr ([output_type , error_type_name ])
880+
839881
840882 # NB: These strings must be indented to at least the same level of
841883 # the function strings in the branches below, otherwise `dedent`
842884 # will pick our indentation level for normalization, which will
843885 # break the "def" indentation presuppositions.
844886 parse_output_method = f"""\
845- lambda x: TypeAdapter( { render_type_expr ( output_type ) } )
887+ lambda x: { output_type_type_adapter_name . value }
846888 .validate_python(
847889 x # type: ignore[arg-type]
848890 )
849891 """
850892 parse_error_method = f"""\
851- lambda x: TypeAdapter( { render_type_expr ( error_type ) } )
893+ lambda x: { error_type_type_adapter_name . value }
852894 .validate_python(
853895 x # type: ignore[arg-type]
854896 )
@@ -871,8 +913,17 @@ def __init__(self, client: river.Client[Any]):
871913 else :
872914 render_init_method = f"encode_{ render_literal_type (init_type )} "
873915 else :
916+ init_type_name = extract_inner_type (init_type )
917+ init_type_type_adapter_name = TypeName (f"{ init_type_name .value } TypeAdapter" )
918+ serdes .append (
919+ (
920+ [init_type_type_adapter_name ],
921+ module_info ,
922+ [f"{ init_type_type_adapter_name .value } = TypeAdapter({ render_type_expr (init_type )} ) # type: ignore" ]
923+ )
924+ )
874925 render_init_method = f"""\
875- lambda x: TypeAdapter( { render_type_expr ( init_type ) } )
926+ lambda x: { init_type_type_adapter_name . name } )
876927 .validate_python
877928 """
878929
@@ -898,8 +949,9 @@ def __init__(self, client: river.Client[Any]):
898949 else :
899950 render_input_method = f"encode_{ render_literal_type (input_type )} "
900951 else :
952+
901953 render_input_method = f"""\
902- lambda x: TypeAdapter( { render_type_expr ( input_type ) } )
954+ lambda x: { input_type_type_adapter_name . value }
903955 .dump_python(
904956 x, # type: ignore[arg-type]
905957 by_alias=True,
0 commit comments