Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
"""
)
current_chunks.append(
f" kind: {render_type_expr(type_name)} | None{value}"
f" kind: {
render_type_expr(
UnionTypeExpr(
[
type_name,
NoneTypeExpr(),
]
)
)
}{value}"
)
else:
value = ""
Expand All @@ -666,7 +675,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
reindent(
" ",
f"""\
{name}: NotRequired[{render_type_expr(type_name)}] | None
{name}: NotRequired[{
render_type_expr(
UnionTypeExpr([type_name, NoneTypeExpr()])
)
}]
""",
)
)
Expand All @@ -675,7 +688,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
reindent(
" ",
f"""\
{name}: {render_type_expr(type_name)} | None = None
{name}: {
render_type_expr(
UnionTypeExpr(
[
type_name,
NoneTypeExpr(),
]
)
)
} = None
""",
)
)
Expand Down Expand Up @@ -1246,6 +1268,8 @@ def schema_to_river_client_codegen(
stdout=subprocess.PIPE,
)
stdout, _ = popen.communicate(contents.encode())
if popen.returncode != 0:
f.write(contents)
f.write(stdout.decode("utf-8"))
except:
f.write(contents)
Expand Down
56 changes: 56 additions & 0 deletions src/replit_river/codegen/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,24 @@ class TypeName:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, TypeName) and other.value == self.value

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class NoneTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, NoneTypeExpr)

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class DictTypeExpr:
Expand All @@ -30,6 +42,12 @@ class DictTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, DictTypeExpr) and other.nested == self.nested

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class ListTypeExpr:
Expand All @@ -38,6 +56,12 @@ class ListTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, ListTypeExpr) and other.nested == self.nested

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class LiteralTypeExpr:
Expand All @@ -46,6 +70,12 @@ class LiteralTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, LiteralTypeExpr) and other.nested == self.nested

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class UnionTypeExpr:
Expand All @@ -54,6 +84,14 @@ class UnionTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, UnionTypeExpr) and set(other.nested) == set(
self.nested
)

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class OpenUnionTypeExpr:
Expand All @@ -62,6 +100,12 @@ class OpenUnionTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, OpenUnionTypeExpr) and other.union == self.union

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


TypeExpression = (
TypeName
Expand Down Expand Up @@ -117,13 +161,25 @@ def render_type_expr(value: TypeExpression) -> str:
literals.append(tpe)
else:
_other.append(tpe)

without_none: list[TypeExpression] = [
x for x in _other if not isinstance(x, NoneTypeExpr)
]
has_none = len(_other) > len(without_none)
_other = without_none

retval: str = " | ".join(render_type_expr(x) for x in _other)
if literals:
_rendered: str = ", ".join(repr(x.nested) for x in literals)
if retval:
retval = f"Literal[{_rendered}] | {retval}"
else:
retval = f"Literal[{_rendered}]"
if has_none:
if retval:
retval = f"{retval} | None"
else:
retval = "None"
return retval
case OpenUnionTypeExpr(inner):
return (
Expand Down
1 change: 0 additions & 1 deletion src/replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)

import grpc
import grpc.aio
from aiochannel import Channel, ChannelClosed
from opentelemetry.propagators.textmap import Setter
from pydantic import BaseModel, ConfigDict, Field
Expand Down
40 changes: 40 additions & 0 deletions tests/codegen/snapshot/codegen_snapshot_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from io import StringIO
from pathlib import Path
from typing import Callable, TextIO

from pytest_snapshot.plugin import Snapshot

from replit_river.codegen.client import schema_to_river_client_codegen


class UnclosableStringIO(StringIO):
def close(self) -> None:
pass


def validate_codegen(
*,
snapshot: Snapshot,
read_schema: Callable[[], TextIO],
target_path: str,
client_name: str,
) -> None:
snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots"
files: dict[Path, UnclosableStringIO] = {}

def file_opener(path: Path) -> TextIO:
buffer = UnclosableStringIO()
assert path not in files, "Codegen attempted to write to the same file twice!"
files[path] = buffer
return buffer

schema_to_river_client_codegen(
read_schema=read_schema,
target_path=target_path,
client_name=client_name,
file_opener=file_opener,
typed_dict_inputs=True,
)
for path, file in files.items():
file.seek(0)
snapshot.assert_match(file.read(), Path(snapshot.snapshot_dir, path))
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class PathologicalClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


from .pathological_method import (
Pathological_MethodInput,
Pathological_MethodInputTypeAdapter,
encode_Pathological_MethodInput,
encode_Pathological_MethodInputObj_Boolean,
encode_Pathological_MethodInputObj_Date,
encode_Pathological_MethodInputObj_Integer,
encode_Pathological_MethodInputObj_Null,
encode_Pathological_MethodInputObj_Number,
encode_Pathological_MethodInputObj_String,
encode_Pathological_MethodInputObj_Uint8Array,
encode_Pathological_MethodInputObj_Undefined,
encode_Pathological_MethodInputReq_Obj_Boolean,
encode_Pathological_MethodInputReq_Obj_Date,
encode_Pathological_MethodInputReq_Obj_Integer,
encode_Pathological_MethodInputReq_Obj_Null,
encode_Pathological_MethodInputReq_Obj_Number,
encode_Pathological_MethodInputReq_Obj_String,
encode_Pathological_MethodInputReq_Obj_Uint8Array,
encode_Pathological_MethodInputReq_Obj_Undefined,
)

boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool)


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def pathological_method(
self,
input: Pathological_MethodInput,
timeout: datetime.timedelta,
) -> bool:
return await self.client.send_rpc(
"test_service",
"pathological_method",
input,
encode_Pathological_MethodInput,
lambda x: boolTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
)
Loading
Loading