Skip to content

Commit ea00715

Browse files
Modularizing client codegen
1 parent bcbe8d3 commit ea00715

File tree

1 file changed

+92
-45
lines changed

1 file changed

+92
-45
lines changed

replit_river/codegen/client.py

Lines changed: 92 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
import json
22
import re
3+
from pathlib import Path
34
from textwrap import dedent, indent
45
from typing import (
56
Any,
67
Dict,
78
List,
89
Literal,
10+
NewType,
911
Optional,
1012
OrderedDict,
1113
Sequence,
@@ -18,6 +20,11 @@
1820
import black
1921
from pydantic import BaseModel, Field, RootModel
2022

23+
ModuleName = NewType("ModuleName", str)
24+
ClassName = NewType("ClassName", str)
25+
FileContents = NewType("FileContents", str)
26+
HandshakeType = NewType("HandshakeType", str)
27+
2128
_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+")
2229
_LITERAL_RE = re.compile(r"^Literal\[(.+)\]$")
2330

@@ -485,13 +492,50 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
485492
return (prefix, chunks)
486493

487494

488-
def generate_individual_service(schema_name: str, handshake_type: str, schema: RiverService, input_base_class: Literal['TypedDict'] | Literal['BaseModel']) -> list[str]:
495+
def generate_common_client(
496+
client_name: str,
497+
handshake_type: HandshakeType,
498+
handshake_chunks: Sequence[str],
499+
modules: list[Tuple[ModuleName, ClassName]],
500+
) -> FileContents:
501+
chunks: list[str] = [FILE_HEADER]
502+
chunks.extend(
503+
[
504+
f"from .{model_name} import {class_name}"
505+
for model_name, class_name in modules
506+
]
507+
)
508+
chunks.extend(handshake_chunks)
509+
chunks.extend(
510+
[
511+
dedent(
512+
f"""\
513+
class {client_name}:
514+
def __init__(self, client: river.Client[{handshake_type}]):
515+
""".rstrip()
516+
)
517+
]
518+
)
519+
for module_name, class_name in modules:
520+
chunks.append(
521+
f" self.{module_name} = {class_name}(client)",
522+
)
523+
524+
return FileContents("\n".join(chunks))
525+
526+
527+
def generate_individual_service(
528+
schema_name: str,
529+
schema: RiverService,
530+
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
531+
) -> Tuple[Path, ModuleName, ClassName, FileContents]:
489532
serdes: list[str] = []
533+
class_name = ClassName(f"{schema_name.title()}Service")
490534
current_chunks: List[str] = [
491535
dedent(
492536
f"""\
493-
class {schema_name.title()}Service:
494-
def __init__(self, client: river.Client[{handshake_type}]):
537+
class {class_name}:
538+
def __init__(self, client: river.Client[Any]):
495539
self.client = client
496540
"""
497541
),
@@ -549,15 +593,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
549593
""".rstrip()
550594

551595
# Init renderer
552-
if input_base_class == 'TypedDict' and init_type:
596+
if input_base_class == "TypedDict" and init_type:
553597
if is_literal(procedure.input):
554598
render_init_method = "lambda x: x"
555599
elif isinstance(
556600
procedure.input, RiverConcreteType
557601
) and procedure.input.type in ["array"]:
558-
assert init_type.startswith(
559-
"List["
560-
) # in case we change to list[...]
602+
assert init_type.startswith("List[") # in case we change to list[...]
561603
_init_type_name = init_type[len("List[") : -len("]")]
562604
render_init_method = (
563605
f"lambda xs: [encode_{_init_type_name}(x) for x in xs]"
@@ -575,15 +617,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
575617
render_init_method = "lambda x: x"
576618

577619
# Input renderer
578-
if input_base_class == 'TypedDict':
620+
if input_base_class == "TypedDict":
579621
if is_literal(procedure.input):
580622
render_input_method = "lambda x: x"
581623
elif isinstance(
582624
procedure.input, RiverConcreteType
583625
) and procedure.input.type in ["array"]:
584-
assert input_type.startswith(
585-
"List["
586-
) # in case we change to list[...]
626+
assert input_type.startswith("List[") # in case we change to list[...]
587627
_input_type_name = input_type[len("List[") : -len("]")]
588628
render_input_method = (
589629
f"lambda xs: [encode_{_input_type_name}(x) for x in xs]"
@@ -760,45 +800,48 @@ async def {name}(
760800
)
761801

762802
current_chunks.append("")
763-
return current_chunks
803+
return (
804+
Path(f"{schema_name}.py"),
805+
ModuleName(schema_name),
806+
class_name,
807+
FileContents("\n".join([FILE_HEADER] + serdes + current_chunks)),
808+
)
764809

765810

766811
def generate_river_client_module(
767812
client_name: str,
768813
schema_root: RiverSchema,
769814
typed_dict_inputs: bool,
770-
) -> Sequence[str]:
771-
chunks: List[str] = [FILE_HEADER]
815+
) -> dict[Path, FileContents]:
816+
files: dict[Path, FileContents] = {}
772817

818+
# Negotiate handshake shape
819+
handshake_chunks: Sequence[str] = []
773820
if schema_root.handshakeSchema is not None:
774-
(handshake_type, handshake_chunks) = encode_type(
821+
(_handshake_type, handshake_chunks) = encode_type(
775822
schema_root.handshakeSchema, "HandshakeSchema", "BaseModel"
776823
)
777-
chunks.extend(handshake_chunks)
824+
handshake_type = HandshakeType(_handshake_type)
778825
else:
779-
handshake_type = "Literal[None]"
826+
handshake_type = HandshakeType("Literal[None]")
780827

781-
input_base_class: Literal['TypedDict'] | Literal['BaseModel'] = "TypedDict" if typed_dict_inputs else "BaseModel"
782-
for schema_name, schema in schema_root.services.items():
783-
current_chunks = generate_individual_service(schema_name, handshake_type, schema, input_base_class)
784-
chunks.extend(current_chunks)
785-
786-
chunks.extend(
787-
[
788-
dedent(
789-
f"""\
790-
class {client_name}:
791-
def __init__(self, client: river.Client[{handshake_type}]):
792-
""".rstrip()
793-
)
794-
]
828+
modules: list[Tuple[ModuleName, ClassName]] = []
829+
input_base_class: Literal["TypedDict"] | Literal["BaseModel"] = (
830+
"TypedDict" if typed_dict_inputs else "BaseModel"
795831
)
796832
for schema_name, schema in schema_root.services.items():
797-
chunks.append(
798-
f" self.{schema_name} = {schema_name.title()}Service(client)",
833+
path, module_name, class_name, file_contents = generate_individual_service(
834+
schema_name, schema, input_base_class
799835
)
836+
files[path] = file_contents
837+
modules.append((module_name, class_name))
838+
839+
main_contents = generate_common_client(
840+
client_name, handshake_type, handshake_chunks, modules
841+
)
842+
files[Path("__init__.py")] = main_contents
800843

801-
return chunks
844+
return files
802845

803846

804847
def schema_to_river_client_codegen(
@@ -810,14 +853,18 @@ def schema_to_river_client_codegen(
810853
"""Generates the lines of a River module."""
811854
with open(schema_path) as f:
812855
schemas = RiverSchemaFile(json.load(f))
813-
with open(target_path, "w") as f:
814-
s = "\n".join(
815-
generate_river_client_module(client_name, schemas.root, typed_dict_inputs)
816-
)
817-
try:
818-
f.write(
819-
black.format_str(s, mode=black.FileMode(string_normalization=False))
820-
)
821-
except:
822-
f.write(s)
823-
raise
856+
for subpath, contents in generate_river_client_module(
857+
client_name, schemas.root, typed_dict_inputs
858+
).items():
859+
module_path = Path(target_path).joinpath(subpath)
860+
module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
861+
with open(module_path, "w") as f:
862+
try:
863+
f.write(
864+
black.format_str(
865+
contents, mode=black.FileMode(string_normalization=False)
866+
)
867+
)
868+
except:
869+
f.write(contents)
870+
raise

0 commit comments

Comments (0)