11import json
22import re
3+ from pathlib import Path
34from textwrap import dedent , indent
45from typing import (
56 Any ,
67 Dict ,
78 List ,
89 Literal ,
10+ NewType ,
911 Optional ,
1012 OrderedDict ,
1113 Sequence ,
1820import black
1921from pydantic import BaseModel , Field , RootModel
2022
23+ ModuleName = NewType ("ModuleName" , str )
24+ ClassName = NewType ("ClassName" , str )
25+ FileContents = NewType ("FileContents" , str )
26+ HandshakeType = NewType ("HandshakeType" , str )
27+
2128_NON_ALNUM_RE = re .compile (r"[^a-zA-Z0-9_]+" )
2229_LITERAL_RE = re .compile (r"^Literal\[(.+)\]$" )
2330
@@ -485,13 +492,50 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
485492 return (prefix , chunks )
486493
487494
488- def generate_individual_service (schema_name : str , handshake_type : str , schema : RiverService , input_base_class : Literal ['TypedDict' ] | Literal ['BaseModel' ]) -> list [str ]:
495+ def generate_common_client (
496+ client_name : str ,
497+ handshake_type : HandshakeType ,
498+ handshake_chunks : Sequence [str ],
499+ modules : list [Tuple [ModuleName , ClassName ]],
500+ ) -> FileContents :
501+ chunks : list [str ] = [FILE_HEADER ]
502+ chunks .extend (
503+ [
504+ f"from .{ model_name } import { class_name } "
505+ for model_name , class_name in modules
506+ ]
507+ )
508+ chunks .extend (handshake_chunks )
509+ chunks .extend (
510+ [
511+ dedent (
512+ f"""\
513+ class { client_name } :
514+ def __init__(self, client: river.Client[{ handshake_type } ]):
515+ """ .rstrip ()
516+ )
517+ ]
518+ )
519+ for module_name , class_name in modules :
520+ chunks .append (
521+ f" self.{ module_name } = { class_name } (client)" ,
522+ )
523+
524+ return FileContents ("\n " .join (chunks ))
525+
526+
527+ def generate_individual_service (
528+ schema_name : str ,
529+ schema : RiverService ,
530+ input_base_class : Literal ["TypedDict" ] | Literal ["BaseModel" ],
531+ ) -> Tuple [Path , ModuleName , ClassName , FileContents ]:
489532 serdes : list [str ] = []
533+ class_name = ClassName (f"{ schema_name .title ()} Service" )
490534 current_chunks : List [str ] = [
491535 dedent (
492536 f"""\
493- class { schema_name . title () } Service :
494- def __init__(self, client: river.Client[{ handshake_type } ]):
537+ class { class_name } :
538+ def __init__(self, client: river.Client[Any ]):
495539 self.client = client
496540 """
497541 ),
@@ -549,15 +593,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
549593 """ .rstrip ()
550594
551595 # Init renderer
552- if input_base_class == ' TypedDict' and init_type :
596+ if input_base_class == " TypedDict" and init_type :
553597 if is_literal (procedure .input ):
554598 render_init_method = "lambda x: x"
555599 elif isinstance (
556600 procedure .input , RiverConcreteType
557601 ) and procedure .input .type in ["array" ]:
558- assert init_type .startswith (
559- "List["
560- ) # in case we change to list[...]
602+ assert init_type .startswith ("List[" ) # in case we change to list[...]
561603 _init_type_name = init_type [len ("List[" ) : - len ("]" )]
562604 render_init_method = (
563605 f"lambda xs: [encode_{ _init_type_name } (x) for x in xs]"
@@ -575,15 +617,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
575617 render_init_method = "lambda x: x"
576618
577619 # Input renderer
578- if input_base_class == ' TypedDict' :
620+ if input_base_class == " TypedDict" :
579621 if is_literal (procedure .input ):
580622 render_input_method = "lambda x: x"
581623 elif isinstance (
582624 procedure .input , RiverConcreteType
583625 ) and procedure .input .type in ["array" ]:
584- assert input_type .startswith (
585- "List["
586- ) # in case we change to list[...]
626+ assert input_type .startswith ("List[" ) # in case we change to list[...]
587627 _input_type_name = input_type [len ("List[" ) : - len ("]" )]
588628 render_input_method = (
589629 f"lambda xs: [encode_{ _input_type_name } (x) for x in xs]"
@@ -760,45 +800,48 @@ async def {name}(
760800 )
761801
762802 current_chunks .append ("" )
763- return current_chunks
803+ return (
804+ Path (f"{ schema_name } .py" ),
805+ ModuleName (schema_name ),
806+ class_name ,
807+ FileContents ("\n " .join ([FILE_HEADER ] + serdes + current_chunks )),
808+ )
764809
765810
766811def generate_river_client_module (
767812 client_name : str ,
768813 schema_root : RiverSchema ,
769814 typed_dict_inputs : bool ,
770- ) -> Sequence [ str ]:
771- chunks : List [ str ] = [ FILE_HEADER ]
815+ ) -> dict [ Path , FileContents ]:
816+ files : dict [ Path , FileContents ] = {}
772817
818+ # Negotiate handshake shape
819+ handshake_chunks : Sequence [str ] = []
773820 if schema_root .handshakeSchema is not None :
774- (handshake_type , handshake_chunks ) = encode_type (
821+ (_handshake_type , handshake_chunks ) = encode_type (
775822 schema_root .handshakeSchema , "HandshakeSchema" , "BaseModel"
776823 )
777- chunks . extend ( handshake_chunks )
824+ handshake_type = HandshakeType ( _handshake_type )
778825 else :
779- handshake_type = "Literal[None]"
826+ handshake_type = HandshakeType ( "Literal[None]" )
780827
781- input_base_class : Literal ['TypedDict' ] | Literal ['BaseModel' ] = "TypedDict" if typed_dict_inputs else "BaseModel"
782- for schema_name , schema in schema_root .services .items ():
783- current_chunks = generate_individual_service (schema_name , handshake_type , schema , input_base_class )
784- chunks .extend (current_chunks )
785-
786- chunks .extend (
787- [
788- dedent (
789- f"""\
790- class { client_name } :
791- def __init__(self, client: river.Client[{ handshake_type } ]):
792- """ .rstrip ()
793- )
794- ]
828+ modules : list [Tuple [ModuleName , ClassName ]] = []
829+ input_base_class : Literal ["TypedDict" ] | Literal ["BaseModel" ] = (
830+ "TypedDict" if typed_dict_inputs else "BaseModel"
795831 )
796832 for schema_name , schema in schema_root .services .items ():
797- chunks . append (
798- f" self. { schema_name } = { schema_name . title () } Service(client)" ,
833+ path , module_name , class_name , file_contents = generate_individual_service (
834+ schema_name , schema , input_base_class
799835 )
836+ files [path ] = file_contents
837+ modules .append ((module_name , class_name ))
838+
839+ main_contents = generate_common_client (
840+ client_name , handshake_type , handshake_chunks , modules
841+ )
842+ files [Path ("__init__.py" )] = main_contents
800843
801- return chunks
844+ return files
802845
803846
804847def schema_to_river_client_codegen (
@@ -810,14 +853,18 @@ def schema_to_river_client_codegen(
810853 """Generates the lines of a River module."""
811854 with open (schema_path ) as f :
812855 schemas = RiverSchemaFile (json .load (f ))
813- with open (target_path , "w" ) as f :
814- s = "\n " .join (
815- generate_river_client_module (client_name , schemas .root , typed_dict_inputs )
816- )
817- try :
818- f .write (
819- black .format_str (s , mode = black .FileMode (string_normalization = False ))
820- )
821- except :
822- f .write (s )
823- raise
856+ for subpath , contents in generate_river_client_module (
857+ client_name , schemas .root , typed_dict_inputs
858+ ).items ():
859+ module_path = Path (target_path ).joinpath (subpath )
860+ module_path .parent .mkdir (mode = 0o755 , parents = True , exist_ok = True )
861+ with open (module_path , "w" ) as f :
862+ try :
863+ f .write (
864+ black .format_str (
865+ contents , mode = black .FileMode (string_normalization = False )
866+ )
867+ )
868+ except :
869+ f .write (contents )
870+ raise
0 commit comments