11import collections
22import os .path
3+ from pathlib import Path
34import tempfile
45from textwrap import dedent
56from typing import DefaultDict , List , Sequence
1112from grpc_tools import protoc
1213
1314from replit_river .codegen .format import reindent
15+ from replit_river .codegen .typing import FileContents , RenderedPath
1416
1517
1618def to_camel_case (snake_str : str ) -> str :
@@ -300,8 +302,9 @@ def generate_river_module(
300302 module_name : str ,
301303 pb_module_name : str ,
302304 fds : descriptor_pb2 .FileDescriptorSet ,
303- ) -> Sequence [ str ]:
305+ ) -> dict [ RenderedPath , FileContents ]:
304306 """Generates the lines of a River module."""
307+ files : dict [RenderedPath , FileContents ] = {}
305308 chunks : List [str ] = [
306309 dedent (
307310 f"""\
@@ -388,7 +391,11 @@ def add_{service.name}Servicer_to_server(
388391 chunks .append (" }" )
389392 chunks .append (" server.add_rpc_handlers(rpc_method_handlers)" )
390393 chunks .append ("" )
391- return chunks
394+
395+ main_contents = FileContents ("\n " .join (chunks ))
396+ files [RenderedPath (str (Path ("__init__.py" )))] = main_contents
397+
398+ return files
392399
393400
394401def proto_to_river_server_codegen (
@@ -411,12 +418,21 @@ def proto_to_river_server_codegen(
411418 )
412419 with open (descriptor_path , "rb" ) as f :
413420 fds .ParseFromString (f .read ())
421+
414422 pb_module_name = os .path .splitext (os .path .basename (proto_path ))[0 ]
415- contents = black .format_str (
416- "\n " .join (generate_river_module (module_name , pb_module_name , fds )),
417- mode = black .FileMode (string_normalization = False ),
418- )
419423 os .makedirs (target_directory , exist_ok = True )
420- output_path = f"{ target_directory } /{ pb_module_name } _river.py"
421- with open (output_path , "w" ) as f :
422- f .write (contents )
424+
425+ for subpath , contents in generate_river_module (module_name , pb_module_name , fds ).items ():
426+ module_path = Path (target_directory ).joinpath (subpath )
427+ module_path .parent .mkdir (mode = 0o755 , parents = True , exist_ok = True )
428+
429+ with open (module_path , "w" ) as f :
430+ try :
431+ f .write (
432+ black .format_str (
433+ contents , mode = black .FileMode (string_normalization = False )
434+ )
435+ )
436+ except :
437+ f .write (contents )
438+ raise
0 commit comments