Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dev-dependencies = [

[tool.ruff]
lint.select = ["F", "E", "W", "I001"]
exclude = ["*/generated/*"]

# Should be kept in sync with mypy.ini in the project root.
# The VSCode mypy extension can only read /mypy.ini.
Expand Down
3 changes: 3 additions & 0 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace
Expand Down Expand Up @@ -60,6 +61,7 @@ async def send_rpc(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
timeout: timedelta,
) -> ResponseType:
with _trace_procedure("rpc", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
Expand All @@ -71,6 +73,7 @@ async def send_rpc(
response_deserializer,
error_deserializer,
span,
timeout,
)

async def send_upload(
Expand Down
18 changes: 17 additions & 1 deletion replit_river/client_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator
from datetime import timedelta
from typing import Any, Callable, Optional, Union

import nanoid # type: ignore
Expand All @@ -8,6 +10,7 @@
from opentelemetry.trace import Span

from replit_river.error_schema import (
ERROR_CODE_CANCEL,
ERROR_CODE_STREAM_CLOSED,
RiverException,
RiverServiceException,
Expand Down Expand Up @@ -39,6 +42,7 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
timeout: timedelta,
) -> ResponseType:
"""Sends a single RPC request to the server.

Expand All @@ -58,7 +62,19 @@ async def send_rpc(
# Handle potential errors during communication
try:
try:
response = await output.get()
async with asyncio.timeout(int(timeout.total_seconds())):
response = await output.get()
except asyncio.TimeoutError as e:
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
await self.send_message(
stream_id=stream_id,
control_flags=STREAM_CLOSED_BIT,
payload={"type": "CLOSE"},
service_name=service_name,
procedure_name=procedure_name,
span=span,
)
raise RiverException(ERROR_CODE_CANCEL, str(e)) from e
except ChannelClosed as e:
raise RiverServiceException(
ERROR_CODE_STREAM_CLOSED,
Expand Down
3 changes: 3 additions & 0 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

Expand Down Expand Up @@ -857,6 +858,7 @@ def __init__(self, client: river.Client[Any]):
async def {name}(
self,
input: {render_type_expr(input_type)},
timeout: datetime.timedelta,
) -> {render_type_expr(output_type)}:
return await self.client.send_rpc(
{repr(schema_name)},
Expand All @@ -865,6 +867,7 @@ async def {name}(
{reindent(" ", render_input_method)},
{reindent(" ", parse_output_method)},
{reindent(" ", parse_error_method)},
timeout,
)
""",
)
Expand Down
2 changes: 1 addition & 1 deletion replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
]
ACK_BIT = 0x0001
STREAM_OPEN_BIT = 0x0002
STREAM_CLOSED_BIT = 0x0004
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2

# these codes are retriable
# if the server sends a response with one of these codes,
Expand Down
15 changes: 8 additions & 7 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,14 @@ async def _handle_messages_from_ws(
)
await self._add_msg_to_stream(msg, stream)
else:
stream = await self._open_stream_and_call_handler(
msg, stream, tg
)
# TODO(dstewart) This looks like it opens a new call to handler
# on ever ws message, instead of demuxing and
# routing.
_stream = await self._open_stream_and_call_handler(msg, tg)
if not stream:
async with self._stream_lock:
self._streams[msg.streamId] = _stream
stream = _stream

if msg.controlFlags & STREAM_CLOSED_BIT != 0:
if stream:
Expand Down Expand Up @@ -457,7 +462,6 @@ async def close_websocket(
async def _open_stream_and_call_handler(
self,
msg: TransportMessage,
stream: Optional[Channel],
tg: Optional[asyncio.TaskGroup],
) -> Channel:
if not self._is_server:
Expand Down Expand Up @@ -496,9 +500,6 @@ async def _open_stream_and_call_handler(
await input_stream.put(msg.payload)
except (RuntimeError, ChannelClosed) as e:
raise InvalidMessageException(e) from e
if not stream:
async with self._stream_lock:
self._streams[msg.streamId] = input_stream
# Start the handler.
self._task_manager.create_task(
handler_func(msg.from_, input_stream, output_stream), tg
Expand Down
16 changes: 8 additions & 8 deletions replit_river/task_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Any, Optional, Set
from typing import Coroutine, Optional, Set

from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException

Expand All @@ -11,7 +11,7 @@ class BackgroundTaskManager:
"""Manages background tasks and logs exceptions."""

def __init__(self) -> None:
self.background_tasks: Set[asyncio.Task] = set()
self.background_tasks: Set[asyncio.Task[None]] = set()

async def cancel_all_tasks(self) -> None:
"""Asynchronously cancels all tasks managed by this instance."""
Expand All @@ -21,8 +21,8 @@ async def cancel_all_tasks(self) -> None:

@staticmethod
async def cancel_task(
task_to_remove: asyncio.Task[Any],
background_tasks: Set[asyncio.Task],
task_to_remove: asyncio.Task[None],
background_tasks: Set[asyncio.Task[None]],
) -> None:
"""Cancels a given task and ensures it is removed from the set of managed tasks.

Expand Down Expand Up @@ -50,8 +50,8 @@ async def cancel_task(

def _task_done_callback(
self,
task_to_remove: asyncio.Task[Any],
background_tasks: Set[asyncio.Task],
task_to_remove: asyncio.Task[None],
background_tasks: Set[asyncio.Task[None]],
) -> None:
"""Callback to be executed when a task is done. It removes the task from the set
and logs any exceptions.
Expand Down Expand Up @@ -83,8 +83,8 @@ def _task_done_callback(
)

def create_task(
self, fn: Any, tg: Optional[asyncio.TaskGroup] = None
) -> asyncio.Task:
self, fn: Coroutine[None, None, None], tg: Optional[asyncio.TaskGroup] = None
) -> asyncio.Task[None]:
"""Creates a task from a callable and adds it to the background tasks set.

Args:
Expand Down
13 changes: 13 additions & 0 deletions tests/codegen/rpc/generated/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class RpcClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
36 changes: 36 additions & 0 deletions tests/codegen/rpc/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
import replit_river as river


from .rpc_method import encode_Rpc_MethodInput, Rpc_MethodInput, Rpc_MethodOutput


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def rpc_method(
self,
input: Rpc_MethodInput,
timeout: datetime.timedelta,
) -> Rpc_MethodOutput:
return await self.client.send_rpc(
"test_service",
"rpc_method",
input,
encode_Rpc_MethodInput,
lambda x: TypeAdapter(Rpc_MethodOutput).validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
x # type: ignore[arg-type]
),
timeout,
)
40 changes: 40 additions & 0 deletions tests/codegen/rpc/generated/test_service/rpc_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ruff: noqa
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Mapping,
Union,
Tuple,
TypedDict,
)

from pydantic import BaseModel, Field, TypeAdapter
from replit_river.error_schema import RiverError

import replit_river as river


encode_Rpc_MethodInput: Callable[["Rpc_MethodInput"], Any] = lambda x: {
k: v
for (k, v) in (
{
"data": x.get("data"),
}
).items()
if v is not None
}


class Rpc_MethodInput(TypedDict):
data: str


class Rpc_MethodOutput(BaseModel):
data: str
32 changes: 32 additions & 0 deletions tests/codegen/rpc/schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"services": {
"test_service": {
"procedures": {
"rpc_method": {
"input": {
"type": "object",
"properties": {
"data": {
"type": "string"
}
},
"required": ["data"]
},
"output": {
"type": "object",
"properties": {
"data": {
"type": "string"
}
},
"required": ["data"]
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}
13 changes: 13 additions & 0 deletions tests/codegen/stream/generated/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class StreamClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
40 changes: 40 additions & 0 deletions tests/codegen/stream/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
import replit_river as river


from .stream_method import (
encode_Stream_MethodInput,
Stream_MethodOutput,
Stream_MethodInput,
)


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def stream_method(
self,
inputStream: AsyncIterable[Stream_MethodInput],
) -> AsyncIterator[Stream_MethodOutput | RiverError]:
return self.client.send_stream(
"test_service",
"stream_method",
None,
inputStream,
None,
encode_Stream_MethodInput,
lambda x: TypeAdapter(Stream_MethodOutput).validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
x # type: ignore[arg-type]
),
)
Loading
Loading