Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@ dev-dependencies = [
"types-nanoid>=2.0.0.20240601",
"pyright>=1.1.389",
"pytest-snapshot>=0.9.0",
"lint",
]

[tool.uv.workspace]
members = ["scripts/lint"]

[tool.uv.sources]
lint = { workspace = true }

[tool.ruff]
lint.select = ["F", "E", "W", "I001"]
exclude = ["*/generated/*", "*/snapshots/*"]
Expand Down
1 change: 1 addition & 0 deletions scripts/lint/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
uv run lint, in lieu of https://github.com/astral-sh/uv/issues/5903
18 changes: 18 additions & 0 deletions scripts/lint/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[project]
name = "lint"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11, <4"
dependencies = []

[project.scripts]
lint = "lint:main"
format = "lint:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/lint"]
File renamed without changes.
28 changes: 16 additions & 12 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ListTypeExpr,
LiteralTypeExpr,
ModuleName,
NoneTypeExpr,
OpenUnionTypeExpr,
RenderedPath,
TypeExpression,
Expand Down Expand Up @@ -170,7 +171,7 @@ def encode_type(
encoder_name: TypeName | None = None # defining this up here to placate mypy
chunks: List[FileContents] = []
if isinstance(type, RiverNotType):
return (TypeName("None"), [], [], set())
return (NoneTypeExpr(), [], [], set())
elif isinstance(type, RiverUnionType):
typeddict_encoder = list[str]()
encoder_names: set[TypeName] = set()
Expand Down Expand Up @@ -379,17 +380,17 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
typeddict_encoder.append(
f"encode_{render_literal_type(inner_type_name)}(x)"
)
case DictTypeExpr(_):
raise ValueError(
"What does it mean to try and encode a dict in"
" this position?"
)
case LiteralTypeExpr(const):
typeddict_encoder.append(repr(const))
case TypeName(value):
typeddict_encoder.append(f"encode_{value}(x)")
case NoneTypeExpr():
typeddict_encoder.append("None")
case other:
typeddict_encoder.append(
f"encode_{render_literal_type(other)}(x)"
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
other
)
raise ValueError(f"What does it mean to have {_o2} here?")
if permit_unknown_members:
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
else:
Expand Down Expand Up @@ -471,7 +472,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
return (TypeName("bool"), [], [], set())
elif type.type == "null" or type.type == "undefined":
typeddict_encoder.append("None")
return (TypeName("None"), [], [], set())
return (NoneTypeExpr(), [], [], set())
elif type.type == "Date":
typeddict_encoder.append("TODO: dstewart")
return (TypeName("datetime.datetime"), [], [], set())
Expand Down Expand Up @@ -511,8 +512,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
)
case LiteralTypeExpr(const):
typeddict_encoder.append(repr(const))
case TypeName(value):
typeddict_encoder.append(f"encode_{value}(x)")
case other:
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
_o1: NoneTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = other
raise ValueError(f"What does it mean to have {_o1} here?")
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
assert type.type == "object", type.type

Expand Down Expand Up @@ -823,7 +827,7 @@ def __init__(self, client: river.Client[Any]):
module_names,
permit_unknown_members=True,
)
if error_type == "None":
if isinstance(error_type, NoneTypeExpr):
error_type = TypeName("RiverError")
else:
serdes.append(
Expand Down Expand Up @@ -916,7 +920,7 @@ def __init__(self, client: river.Client[Any]):
f"Unable to derive the input encoder from: {input_type}"
)

if output_type == "None":
if isinstance(output_type, NoneTypeExpr):
parse_output_method = "lambda x: None"

if procedure.type == "rpc":
Expand Down
35 changes: 14 additions & 21 deletions src/replit_river/codegen/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")


@dataclass(frozen=True)
class NoneTypeExpr:
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")


@dataclass(frozen=True)
class DictTypeExpr:
nested: "TypeExpression"
Expand Down Expand Up @@ -59,6 +65,7 @@ def __str__(self) -> str:

TypeExpression = (
TypeName
| NoneTypeExpr
| DictTypeExpr
| ListTypeExpr
| LiteralTypeExpr
Expand Down Expand Up @@ -86,6 +93,8 @@ def render_type_expr(value: TypeExpression) -> str:
)
case TypeName(name):
return name
case NoneTypeExpr():
return "None"
case other:
assert_never(other)

Expand All @@ -112,33 +121,17 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
)
case TypeName(name):
return TypeName(name)
case NoneTypeExpr():
raise ValueError(f"Attempting to extract from a literal 'None': {value}")
case other:
assert_never(other)


def ensure_literal_type(value: TypeExpression) -> TypeName:
match value:
case DictTypeExpr(_):
raise ValueError(
f"Unexpected expression when expecting a type name: {value}"
)
case ListTypeExpr(_):
raise ValueError(
f"Unexpected expression when expecting a type name: {value}"
)
case LiteralTypeExpr(_):
raise ValueError(
f"Unexpected expression when expecting a type name: {value}"
)
case UnionTypeExpr(_):
raise ValueError(
f"Unexpected expression when expecting a type name: {value}"
)
case OpenUnionTypeExpr(_):
raise ValueError(
f"Unexpected expression when expecting a type name: {value}"
)
case TypeName(name):
return TypeName(name)
case other:
assert_never(other)
raise ValueError(
f"Unexpected expression when expecting a type name: {other}"
)
13 changes: 13 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading