6464
6565from pydantic import TypeAdapter
6666
67- from replit_river.error_schema import RiverError
67+ from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
6868import replit_river as river
6969
7070"""
@@ -763,6 +763,27 @@ def generate_individual_service(
763763 input_base_class : Literal ["TypedDict" ] | Literal ["BaseModel" ],
764764) -> Tuple [ModuleName , ClassName , dict [RenderedPath , FileContents ]]:
765765 serdes : list [Tuple [list [TypeName ], list [ModuleName ], list [FileContents ]]] = []
766+
767+ def _type_adapter_definition (
768+ type_adapter_name : TypeName ,
769+ _type : TypeExpression ,
770+ module_info : list [ModuleName ],
771+ ) -> tuple [list [TypeName ], list [ModuleName ], list [FileContents ]]:
772+ rendered_type_expr = render_type_expr (_type )
773+ return (
774+ [type_adapter_name ],
775+ module_info ,
776+ [
777+ FileContents (
778+ dedent (f"""
779+ { render_type_expr (type_adapter_name )} : TypeAdapter[Any] = (
780+ TypeAdapter({ rendered_type_expr } )
781+ )
782+ """ )
783+ )
784+ ],
785+ )
786+
766787 class_name = ClassName (f"{ schema_name .title ()} Service" )
767788 current_chunks : List [str ] = [
768789 dedent (
@@ -798,27 +819,46 @@ def __init__(self, client: river.Client[Any]):
798819 module_names ,
799820 permit_unknown_members = False ,
800821 )
822+ input_type_name = extract_inner_type (input_type )
823+ input_type_type_adapter_name = TypeName (
824+ f"{ render_literal_type (input_type_name )} TypeAdapter"
825+ )
801826 serdes .append (
802827 (
803828 [extract_inner_type (input_type ), * encoder_names ],
804829 module_info ,
805830 input_chunks ,
806831 )
807832 )
833+ serdes .append (
834+ _type_adapter_definition (
835+ input_type_type_adapter_name , input_type , module_info
836+ )
837+ )
808838 output_type , module_info , output_chunks , encoder_names = encode_type (
809839 procedure .output ,
810840 TypeName (f"{ name .title ()} Output" ),
811841 "BaseModel" ,
812842 module_names ,
813843 permit_unknown_members = True ,
814844 )
845+ output_type_name = extract_inner_type (output_type )
815846 serdes .append (
816847 (
817- [extract_inner_type ( output_type ) , * encoder_names ],
848+ [output_type_name , * encoder_names ],
818849 module_info ,
819850 output_chunks ,
820851 )
821852 )
853+ output_type_type_adapter_name = TypeName (
854+ f"{ render_literal_type (output_type_name )} TypeAdapter"
855+ )
856+ serdes .append (
857+ _type_adapter_definition (
858+ output_type_type_adapter_name , output_type , module_info
859+ )
860+ )
861+ output_module_info = module_info
822862 if procedure .errors :
823863 error_type , module_info , errors_chunks , encoder_names = encode_type (
824864 procedure .errors ,
@@ -828,27 +868,43 @@ def __init__(self, client: river.Client[Any]):
828868 permit_unknown_members = True ,
829869 )
830870 if isinstance (error_type , NoneTypeExpr ):
831- error_type = TypeName ("RiverError" )
871+ error_type_name = TypeName ("RiverError" )
872+ error_type = error_type_name
832873 else :
833- serdes . append (
834- ([ extract_inner_type ( error_type ) ], module_info , errors_chunks )
835- )
874+ error_type_name = extract_inner_type ( error_type )
875+ serdes . append (([ error_type_name ], module_info , errors_chunks ) )
876+
836877 else :
837- error_type = TypeName ("RiverError" )
838- output_or_error_type = UnionTypeExpr ([output_type , error_type ])
878+ error_type_name = TypeName ("RiverError" )
879+ error_type = error_type_name
880+
881+ error_type_type_adapter_name = TypeName (
882+ f"{ render_literal_type (error_type_name )} TypeAdapter"
883+ )
884+ if error_type_type_adapter_name .value != "RiverErrorTypeAdapter" :
885+ if len (module_info ) == 0 :
886+ module_info = output_module_info
887+ serdes .append (
888+ _type_adapter_definition (
889+ error_type_type_adapter_name , error_type , module_info
890+ )
891+ )
892+ output_or_error_type = UnionTypeExpr ([output_type , error_type_name ])
839893
840894 # NB: These strings must be indented to at least the same level of
841895 # the function strings in the branches below, otherwise `dedent`
842896 # will pick our indentation level for normalization, which will
843897 # break the "def" indentation presuppositions.
898+ output_type_adapter = render_literal_type (output_type_type_adapter_name )
844899 parse_output_method = f"""\
845- lambda x: TypeAdapter( { render_type_expr ( output_type ) } )
900+ lambda x: { output_type_adapter }
846901 .validate_python(
847902 x # type: ignore[arg-type]
848903 )
849904 """
905+ error_type_adapter = render_literal_type (error_type_type_adapter_name )
850906 parse_error_method = f"""\
851- lambda x: TypeAdapter( { render_type_expr ( error_type ) } )
907+ lambda x: { error_type_adapter }
852908 .validate_python(
853909 x # type: ignore[arg-type]
854910 )
@@ -871,9 +927,18 @@ def __init__(self, client: river.Client[Any]):
871927 else :
872928 render_init_method = f"encode_{ render_literal_type (init_type )} "
873929 else :
930+ init_type_name = extract_inner_type (init_type )
931+ init_type_type_adapter_name = TypeName (
932+ f"{ init_type_name .value } TypeAdapter"
933+ )
934+ serdes .append (
935+ _type_adapter_definition (
936+ init_type_type_adapter_name , init_type , module_info
937+ )
938+ )
874939 render_init_method = f"""\
875- lambda x: TypeAdapter( { render_type_expr (init_type ) } )
876- .validate_python
940+ lambda x: { render_type_expr (init_type_type_adapter_name ) }
941+ .validate_python
877942 """
878943
879944 assert init_type is None or render_init_method , (
@@ -889,17 +954,17 @@ def __init__(self, client: river.Client[Any]):
889954 procedure .input , RiverConcreteType
890955 ) and procedure .input .type in ["array" ]:
891956 match input_type :
892- case ListTypeExpr (input_type_name ):
957+ case ListTypeExpr (list_type ):
893958 render_input_method = f"""\
894959 lambda xs: [
895- encode_{ render_literal_type (input_type_name )} (x) for x in xs
960+ encode_{ render_literal_type (list_type )} (x) for x in xs
896961 ]
897962 """
898963 else :
899964 render_input_method = f"encode_{ render_literal_type (input_type )} "
900965 else :
901966 render_input_method = f"""\
902- lambda x: TypeAdapter( { render_type_expr (input_type ) } )
967+ lambda x: { render_type_expr (input_type_type_adapter_name ) }
903968 .dump_python(
904969 x, # type: ignore[arg-type]
905970 by_alias=True,
0 commit comments