55from textwrap import dedent
66from typing import (
77 Any ,
8+ Callable ,
89 Dict ,
910 List ,
1011 Literal ,
1112 Optional ,
1213 OrderedDict ,
1314 Sequence ,
1415 Set ,
16+ TextIO ,
1517 Tuple ,
1618 Union ,
1719 cast ,
3234 TypeExpression ,
3335 TypeName ,
3436 UnionTypeExpr ,
37+ UnknownTypeExpr ,
3538 ensure_literal_type ,
3639 extract_inner_type ,
3740 render_type_expr ,
8083 Literal,
8184 Optional,
8285 Mapping,
86+ NewType,
8387 NotRequired,
8488 Union,
8589 Tuple,
@@ -160,6 +164,7 @@ def encode_type(
160164 prefix : TypeName ,
161165 base_model : str ,
162166 in_module : list [ModuleName ],
167+ permit_unknown_members : bool ,
163168) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
164169 encoder_name : Optional [str ] = None # defining this up here to placate mypy
165170 chunks : List [FileContents ] = []
@@ -256,6 +261,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
256261 TypeName (f"{ pfx } { i } " ),
257262 base_model ,
258263 in_module ,
264+ permit_unknown_members = permit_unknown_members ,
259265 )
260266 one_of .append (type_name )
261267 chunks .extend (contents )
@@ -283,7 +289,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
283289 else :
284290 oneof_t = oneof_ts [0 ]
285291 type_name , _ , contents , _ = encode_type (
286- oneof_t , TypeName (pfx ), base_model , in_module
292+ oneof_t ,
293+ TypeName (pfx ),
294+ base_model ,
295+ in_module ,
296+ permit_unknown_members = permit_unknown_members ,
287297 )
288298 one_of .append (type_name )
289299 chunks .extend (contents )
@@ -301,6 +311,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
301311 else
302312 """ ,
303313 )
314+ if permit_unknown_members :
315+ unknown_name = TypeName (f"{ prefix } AnyOf__Unknown" )
316+ chunks .append (
317+ FileContents (
318+ f"{ unknown_name } = NewType({ repr (unknown_name )} , object)"
319+ )
320+ )
321+ one_of .append (UnknownTypeExpr (unknown_name ))
304322 chunks .append (
305323 FileContents (
306324 f"{ prefix } = { render_type_expr (UnionTypeExpr (one_of ))} "
@@ -336,7 +354,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
336354 typeddict_encoder = []
337355 for i , t in enumerate (type .anyOf ):
338356 type_name , _ , contents , _ = encode_type (
339- t , TypeName (f"{ prefix } AnyOf_{ i } " ), base_model , in_module
357+ t ,
358+ TypeName (f"{ prefix } AnyOf_{ i } " ),
359+ base_model ,
360+ in_module ,
361+ permit_unknown_members = permit_unknown_members ,
340362 )
341363 any_of .append (type_name )
342364 chunks .extend (contents )
@@ -363,6 +385,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
363385 typeddict_encoder .append (
364386 f"encode_{ ensure_literal_type (other )} (x)"
365387 )
388+ if permit_unknown_members :
389+ unknown_name = TypeName (f"{ prefix } AnyOf__Unknown" )
390+ chunks .append (
391+ FileContents (f"{ unknown_name } = NewType({ repr (unknown_name )} , object)" )
392+ )
393+ any_of .append (UnknownTypeExpr (unknown_name ))
366394 if is_literal (type ):
367395 typeddict_encoder = ["x" ]
368396 chunks .append (
@@ -404,6 +432,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
404432 prefix ,
405433 base_model ,
406434 in_module ,
435+ permit_unknown_members = permit_unknown_members ,
407436 )
408437 elif isinstance (type , RiverConcreteType ):
409438 typeddict_encoder = list [str ]()
@@ -446,7 +475,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
446475 return (TypeName ("datetime.datetime" ), [], [], set ())
447476 elif type .type == "array" and type .items :
448477 type_name , module_info , type_chunks , encoder_names = encode_type (
449- type .items , prefix , base_model , in_module
478+ type .items ,
479+ prefix ,
480+ base_model ,
481+ in_module ,
482+ permit_unknown_members = permit_unknown_members ,
450483 )
451484 typeddict_encoder .append ("TODO: dstewart" )
452485 return (ListTypeExpr (type_name ), module_info , type_chunks , encoder_names )
@@ -460,6 +493,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
460493 prefix ,
461494 base_model ,
462495 in_module ,
496+ permit_unknown_members = permit_unknown_members ,
463497 )
464498 # TODO(dstewart): This structure changed since we were incorrectly leaking
465499 # ListTypeExprs into codegen. This generated code is
@@ -494,7 +528,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
494528 ) in sorted (list (type .properties .items ()), key = lambda xs : xs [0 ]):
495529 typeddict_encoder .append (f"{ repr (name )} :" )
496530 type_name , _ , contents , _ = encode_type (
497- prop , TypeName (prefix + name .title ()), base_model , in_module
531+ prop ,
532+ TypeName (prefix + name .title ()),
533+ base_model ,
534+ in_module ,
535+ permit_unknown_members = permit_unknown_members ,
498536 )
499537 encoder_name = None
500538 chunks .extend (contents )
@@ -685,7 +723,7 @@ def generate_common_client(
685723 chunks .extend (
686724 [
687725 f"from .{ model_name } import { class_name } "
688- for model_name , class_name in modules
726+ for model_name , class_name in sorted ( modules , key = lambda kv : kv [ 1 ])
689727 ]
690728 )
691729 chunks .extend (handshake_chunks )
@@ -732,6 +770,7 @@ def __init__(self, client: river.Client[Any]):
732770 TypeName (f"{ name .title ()} Init" ),
733771 input_base_class ,
734772 module_names ,
773+ permit_unknown_members = False ,
735774 )
736775 serdes .append (
737776 (
@@ -745,6 +784,7 @@ def __init__(self, client: river.Client[Any]):
745784 TypeName (f"{ name .title ()} Input" ),
746785 input_base_class ,
747786 module_names ,
787+ permit_unknown_members = False ,
748788 )
749789 serdes .append (
750790 (
@@ -758,6 +798,7 @@ def __init__(self, client: river.Client[Any]):
758798 TypeName (f"{ name .title ()} Output" ),
759799 "BaseModel" ,
760800 module_names ,
801+ permit_unknown_members = True ,
761802 )
762803 serdes .append (
763804 (
@@ -772,6 +813,7 @@ def __init__(self, client: river.Client[Any]):
772813 TypeName (f"{ name .title ()} Errors" ),
773814 "RiverError" ,
774815 module_names ,
816+ permit_unknown_members = True ,
775817 )
776818 if error_type == "None" :
777819 error_type = TypeName ("RiverError" )
@@ -822,9 +864,9 @@ def __init__(self, client: river.Client[Any]):
822864 .validate_python
823865 """
824866
825- assert (
826- init_type is None or render_init_method
827- ), f"Unable to derive the init encoder from: { input_type } "
867+ assert init_type is None or render_init_method , (
868+ f"Unable to derive the init encoder from: { input_type } "
869+ )
828870
829871 # Input renderer
830872 render_input_method : Optional [str ] = None
@@ -862,9 +904,9 @@ def __init__(self, client: river.Client[Any]):
862904 ):
863905 render_input_method = "lambda x: x"
864906
865- assert (
866- render_input_method
867- ), f"Unable to derive the input encoder from: { input_type } "
907+ assert render_input_method , (
908+ f"Unable to derive the input encoder from: { input_type } "
909+ )
868910
869911 if output_type == "None" :
870912 parse_output_method = "lambda x: None"
@@ -1038,7 +1080,7 @@ async def {name}(
10381080 emitted_files [file_path ] = FileContents ("\n " .join ([existing ] + contents ))
10391081
10401082 rendered_imports = [
1041- f"from .{ dotted_modules } import { ', ' .join (names )} "
1083+ f"from .{ dotted_modules } import { ', ' .join (sorted ( names ) )} "
10421084 for dotted_modules , names in imports .items ()
10431085 ]
10441086
@@ -1063,7 +1105,11 @@ def generate_river_client_module(
10631105 handshake_chunks : list [str ] = []
10641106 if schema_root .handshakeSchema is not None :
10651107 _handshake_type , _ , contents , _ = encode_type (
1066- schema_root .handshakeSchema , TypeName ("HandshakeSchema" ), "BaseModel" , []
1108+ schema_root .handshakeSchema ,
1109+ TypeName ("HandshakeSchema" ),
1110+ "BaseModel" ,
1111+ [],
1112+ permit_unknown_members = False ,
10671113 )
10681114 handshake_chunks .extend (contents )
10691115 handshake_type = HandshakeType (render_type_expr (_handshake_type ))
@@ -1090,25 +1136,29 @@ def generate_river_client_module(
10901136
10911137
def schema_to_river_client_codegen(
    read_schema: Callable[[], TextIO],
    target_path: str,
    client_name: str,
    typed_dict_inputs: bool,
    file_opener: Callable[[Path], TextIO],
) -> None:
    """Generate the River client modules for a schema and write them out.

    Args:
        read_schema: Zero-argument callable returning a text stream (used as
            a context manager) from which the JSON schema is loaded.
        target_path: Directory under which generated modules are placed.
        client_name: Forwarded to ``generate_river_client_module``.
        typed_dict_inputs: Forwarded to ``generate_river_client_module``.
        file_opener: Callable that, given the destination path of a module,
            returns a writable text stream (used as a context manager).

    Raises:
        subprocess.CalledProcessError: If ``ruff format`` exits with a
            non-zero status. The unformatted contents are written to the
            destination before the error propagates.
        Exception: Any other failure while launching the formatter is
            re-raised after the unformatted contents have been written.
    """
    with read_schema() as schema_file:
        schemas = RiverSchemaFile(json.load(schema_file))

    generated = generate_river_client_module(
        client_name, schemas.root, typed_dict_inputs
    )
    for subpath, contents in generated.items():
        module_path = Path(target_path).joinpath(subpath)
        module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        with file_opener(module_path) as out:
            try:
                # `check=True` surfaces formatter failures; otherwise a failed
                # run would silently leave an empty/partial module behind.
                completed = subprocess.run(
                    ["ruff", "format", "-"],
                    input=contents.encode(),
                    stdout=subprocess.PIPE,
                    check=True,
                )
            except BaseException:
                # Keep the raw generated source on disk to aid debugging;
                # the error is re-raised, so nothing is swallowed here.
                out.write(contents)
                raise
            out.write(completed.stdout.decode("utf-8"))
0 commit comments