Skip to content

Commit e2eba79

Browse files
[client codegen] Unknown enumeration values, attempt 2 (#135)
Why === While technically the previously emitted deserializers were functional, the ergonomics were basically unusable. I fixed the local development flow so I could complete the whole feature end-to-end this time, confirming that it works. What changed ============ - Wrapping unknown enumerator parameters in a newly exposed `RiverUnknownValue`, emitted via a newly exposed `raise_unknown` It would be preferred to have had the `Annotated[...]` directly on underlying models instead of leaking the unknown status into the top-level enumerations, but that was going to require more surgery. The client migration from this encoding to that encoding will be very easy. Test plan ========= _Describe what you did to test this change to a level of detail that allows your reviewer to test it_
1 parent 32006fb commit e2eba79

File tree

10 files changed

+82
-64
lines changed

10 files changed

+82
-64
lines changed

src/replit_river/client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
from opentelemetry import trace
99
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
10+
from pydantic import (
11+
BaseModel,
12+
ValidationInfo,
13+
)
1014

1115
from replit_river.client_transport import ClientTransport
1216
from replit_river.error_schema import RiverError, RiverException
@@ -27,6 +31,21 @@
2731
tracer = trace.get_tracer(__name__)
2832

2933

34+
@dataclass(frozen=True)
35+
class RiverUnknownValue(BaseModel):
36+
tag: Literal["RiverUnknownValue"]
37+
value: Any
38+
39+
40+
def translate_unknown_value(
41+
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
42+
) -> Any | RiverUnknownValue:
43+
try:
44+
return handler(value)
45+
except Exception:
46+
return RiverUnknownValue(tag="RiverUnknownValue", value=value)
47+
48+
3049
class Client(Generic[HandshakeMetadataType]):
3150
def __init__(
3251
self,

src/replit_river/codegen/client.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
ListTypeExpr,
3131
LiteralTypeExpr,
3232
ModuleName,
33+
OpenUnionTypeExpr,
3334
RenderedPath,
3435
TypeExpression,
3536
TypeName,
3637
UnionTypeExpr,
37-
UnknownTypeExpr,
3838
ensure_literal_type,
3939
extract_inner_type,
4040
render_type_expr,
@@ -83,15 +83,16 @@
8383
Literal,
8484
Optional,
8585
Mapping,
86-
NewType,
8786
NotRequired,
8887
Union,
8988
Tuple,
9089
TypedDict,
9190
)
91+
from typing_extensions import Annotated
9292
93-
from pydantic import BaseModel, Field, TypeAdapter
93+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
9494
from replit_river.error_schema import RiverError
95+
from replit_river.client import RiverUnknownValue, translate_unknown_value
9596
9697
import replit_river as river
9798
@@ -311,19 +312,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
311312
else
312313
""",
313314
)
315+
union: TypeExpression
314316
if permit_unknown_members:
315-
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
316-
chunks.append(
317-
FileContents(
318-
f"{unknown_name} = NewType({repr(unknown_name)}, object)"
319-
)
320-
)
321-
one_of.append(UnknownTypeExpr(unknown_name))
322-
chunks.append(
323-
FileContents(
324-
f"{prefix} = {render_type_expr(UnionTypeExpr(one_of))}"
325-
)
326-
)
317+
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
318+
else:
319+
union = UnionTypeExpr(one_of)
320+
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
327321
chunks.append(FileContents(""))
328322

329323
if base_model == "TypedDict":
@@ -386,16 +380,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
386380
f"encode_{ensure_literal_type(other)}(x)"
387381
)
388382
if permit_unknown_members:
389-
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
390-
chunks.append(
391-
FileContents(f"{unknown_name} = NewType({repr(unknown_name)}, object)")
392-
)
393-
any_of.append(UnknownTypeExpr(unknown_name))
383+
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
384+
else:
385+
union = UnionTypeExpr(any_of)
394386
if is_literal(type):
395387
typeddict_encoder = ["x"]
396-
chunks.append(
397-
FileContents(f"{prefix} = {render_type_expr(UnionTypeExpr(any_of))}")
398-
)
388+
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
399389
if base_model == "TypedDict":
400390
encoder_name = TypeName(f"encode_{prefix}")
401391
encoder_names.add(encoder_name)

src/replit_river/codegen/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def main() -> None:
3030
client = subparsers.add_parser(
3131
"client", help="Codegen a River client from JSON schema"
3232
)
33-
client.add_argument("--output", help="output file", required=True)
33+
client.add_argument("--output", help="output path", required=True)
3434
client.add_argument("--client-name", help="name of the class", required=True)
3535
client.add_argument(
3636
"--typed-dict-inputs",

src/replit_river/codegen/typing.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class UnionTypeExpr:
3131

3232

3333
@dataclass
34-
class UnknownTypeExpr:
35-
name: TypeName
34+
class OpenUnionTypeExpr:
35+
union: UnionTypeExpr
3636

3737

3838
TypeExpression = (
@@ -41,7 +41,7 @@ class UnknownTypeExpr:
4141
| ListTypeExpr
4242
| LiteralTypeExpr
4343
| UnionTypeExpr
44-
| UnknownTypeExpr
44+
| OpenUnionTypeExpr
4545
)
4646

4747

@@ -55,10 +55,15 @@ def render_type_expr(value: TypeExpression) -> str:
5555
return f"Literal[{repr(inner)}]"
5656
case UnionTypeExpr(inner):
5757
return " | ".join(render_type_expr(x) for x in inner)
58+
case OpenUnionTypeExpr(inner):
59+
return (
60+
"Annotated["
61+
f"{render_type_expr(inner)} | RiverUnknownValue,"
62+
"WrapValidator(translate_unknown_value)"
63+
"]"
64+
)
5865
case str(name):
5966
return TypeName(name)
60-
case UnknownTypeExpr(name):
61-
return TypeName(name)
6267
case other:
6368
assert_never(other)
6469

@@ -75,10 +80,12 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
7580
raise ValueError(
7681
f"Attempting to extract from a union, currently not possible: {value}"
7782
)
83+
case OpenUnionTypeExpr(_):
84+
raise ValueError(
85+
f"Attempting to extract from a union, currently not possible: {value}"
86+
)
7887
case str(name):
7988
return TypeName(name)
80-
case UnknownTypeExpr(name):
81-
return name
8289
case other:
8390
assert_never(other)
8491

@@ -101,9 +108,11 @@ def ensure_literal_type(value: TypeExpression) -> TypeName:
101108
raise ValueError(
102109
f"Unexpected expression when expecting a type name: {value}"
103110
)
111+
case OpenUnionTypeExpr(_):
112+
raise ValueError(
113+
f"Unexpected expression when expecting a type name: {value}"
114+
)
104115
case str(name):
105116
return TypeName(name)
106-
case UnknownTypeExpr(name):
107-
return name
108117
case other:
109118
assert_never(other)

tests/codegen/rpc/generated/test_service/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import replit_river as river
1010

1111

12-
from .rpc_method import encode_Rpc_MethodInput, Rpc_MethodInput, Rpc_MethodOutput
12+
from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput
1313

1414

1515
class Test_ServiceService:

tests/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
Literal,
1111
Optional,
1212
Mapping,
13-
NewType,
1413
NotRequired,
1514
Union,
1615
Tuple,
1716
TypedDict,
1817
)
18+
from typing_extensions import Annotated
1919

20-
from pydantic import BaseModel, Field, TypeAdapter
20+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
2121
from replit_river.error_schema import RiverError
22+
from replit_river.client import RiverUnknownValue, translate_unknown_value
2223

2324
import replit_river as river
2425

tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,27 @@
1010
Literal,
1111
Optional,
1212
Mapping,
13-
NewType,
1413
NotRequired,
1514
Union,
1615
Tuple,
1716
TypedDict,
1817
)
18+
from typing_extensions import Annotated
1919

20-
from pydantic import BaseModel, Field, TypeAdapter
20+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
2121
from replit_river.error_schema import RiverError
22+
from replit_river.client import RiverUnknownValue, translate_unknown_value
2223

2324
import replit_river as river
2425

2526

2627
NeedsenumInput = Literal["in_first"] | Literal["in_second"]
2728
encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x
28-
NeedsenumOutputAnyOf__Unknown = NewType("NeedsenumOutputAnyOf__Unknown", object)
29-
NeedsenumOutput = (
30-
Literal["out_first"] | Literal["out_second"] | NeedsenumOutputAnyOf__Unknown
31-
)
32-
NeedsenumErrorsAnyOf__Unknown = NewType("NeedsenumErrorsAnyOf__Unknown", object)
33-
NeedsenumErrors = (
34-
Literal["err_first"] | Literal["err_second"] | NeedsenumErrorsAnyOf__Unknown
35-
)
29+
NeedsenumOutput = Annotated[
30+
Literal["out_first"] | Literal["out_second"] | RiverUnknownValue,
31+
WrapValidator(translate_unknown_value),
32+
]
33+
NeedsenumErrors = Annotated[
34+
Literal["err_first"] | Literal["err_second"] | RiverUnknownValue,
35+
WrapValidator(translate_unknown_value),
36+
]

tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
Literal,
1111
Optional,
1212
Mapping,
13-
NewType,
1413
NotRequired,
1514
Union,
1615
Tuple,
1716
TypedDict,
1817
)
18+
from typing_extensions import Annotated
1919

20-
from pydantic import BaseModel, Field, TypeAdapter
20+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
2121
from replit_river.error_schema import RiverError
22+
from replit_river.client import RiverUnknownValue, translate_unknown_value
2223

2324
import replit_river as river
2425

@@ -90,14 +91,12 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel):
9091
bar: int
9192

9293

93-
NeedsenumobjectOutputFooAnyOf__Unknown = NewType(
94-
"NeedsenumobjectOutputFooAnyOf__Unknown", object
95-
)
96-
NeedsenumobjectOutputFoo = (
94+
NeedsenumobjectOutputFoo = Annotated[
9795
NeedsenumobjectOutputFooOneOf_out_first
9896
| NeedsenumobjectOutputFooOneOf_out_second
99-
| NeedsenumobjectOutputFooAnyOf__Unknown
100-
)
97+
| RiverUnknownValue,
98+
WrapValidator(translate_unknown_value),
99+
]
101100

102101

103102
class NeedsenumobjectOutput(BaseModel):
@@ -112,14 +111,12 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError):
112111
borp: Optional[Literal["err_second"]] = None
113112

114113

115-
NeedsenumobjectErrorsFooAnyOf__Unknown = NewType(
116-
"NeedsenumobjectErrorsFooAnyOf__Unknown", object
117-
)
118-
NeedsenumobjectErrorsFoo = (
114+
NeedsenumobjectErrorsFoo = Annotated[
119115
NeedsenumobjectErrorsFooAnyOf_0
120116
| NeedsenumobjectErrorsFooAnyOf_1
121-
| NeedsenumobjectErrorsFooAnyOf__Unknown
122-
)
117+
| RiverUnknownValue,
118+
WrapValidator(translate_unknown_value),
119+
]
123120

124121

125122
class NeedsenumobjectErrors(RiverError):

tests/codegen/stream/generated/test_service/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111

1212
from .stream_method import (
13-
encode_Stream_MethodInput,
14-
Stream_MethodOutput,
1513
Stream_MethodInput,
14+
Stream_MethodOutput,
15+
encode_Stream_MethodInput,
1616
)
1717

1818

tests/codegen/stream/generated/test_service/stream_method.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
Literal,
1111
Optional,
1212
Mapping,
13-
NewType,
1413
NotRequired,
1514
Union,
1615
Tuple,
1716
TypedDict,
1817
)
18+
from typing_extensions import Annotated
1919

20-
from pydantic import BaseModel, Field, TypeAdapter
20+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
2121
from replit_river.error_schema import RiverError
22+
from replit_river.client import RiverUnknownValue, translate_unknown_value
2223

2324
import replit_river as river
2425

0 commit comments

Comments
 (0)