Skip to content

Commit 89164f4

Browse files
Implement RPC timeouts
1 parent 23c556d commit 89164f4

File tree

6 files changed

+28
-2
lines changed

6 files changed

+28
-2
lines changed

replit_river/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import Any, Generator, Generic, Literal, Optional, Union
56

67
from opentelemetry import trace
@@ -60,6 +61,7 @@ async def send_rpc(
6061
request_serializer: Callable[[RequestType], Any],
6162
response_deserializer: Callable[[Any], ResponseType],
6263
error_deserializer: Callable[[Any], ErrorType],
64+
timeout: timedelta,
6365
) -> ResponseType:
6466
with _trace_procedure("rpc", service_name, procedure_name) as span:
6567
session = await self._transport.get_or_create_session()
@@ -71,6 +73,7 @@ async def send_rpc(
7173
response_deserializer,
7274
error_deserializer,
7375
span,
76+
timeout,
7477
)
7578

7679
async def send_upload(

replit_river/client_session.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import asyncio
12
import logging
23
from collections.abc import AsyncIterable, AsyncIterator
4+
from datetime import timedelta
35
from typing import Any, Callable, Optional, Union
46

57
import nanoid # type: ignore
@@ -8,6 +10,7 @@
810
from opentelemetry.trace import Span
911

1012
from replit_river.error_schema import (
13+
ERROR_CODE_CANCEL,
1114
ERROR_CODE_STREAM_CLOSED,
1215
RiverException,
1316
RiverServiceException,
@@ -39,6 +42,7 @@ async def send_rpc(
3942
response_deserializer: Callable[[Any], ResponseType],
4043
error_deserializer: Callable[[Any], ErrorType],
4144
span: Span,
45+
timeout: timedelta,
4246
) -> ResponseType:
4347
"""Sends a single RPC request to the server.
4448
@@ -58,7 +62,19 @@ async def send_rpc(
5862
# Handle potential errors during communication
5963
try:
6064
try:
61-
response = await output.get()
65+
async with asyncio.timeout(int(timeout.total_seconds())):
66+
response = await output.get()
67+
except asyncio.CancelledError as e:
68+
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
69+
await self.send_message(
70+
stream_id=stream_id,
71+
control_flags=STREAM_CLOSED_BIT,
72+
payload={"type": "CLOSE"},
73+
service_name=service_name,
74+
procedure_name=procedure_name,
75+
span=span,
76+
)
77+
raise RiverException(ERROR_CODE_CANCEL, str(e)) from e
6278
except ChannelClosed as e:
6379
raise RiverServiceException(
6480
ERROR_CODE_STREAM_CLOSED,

replit_river/codegen/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
# Code generated by river.codegen. DO NOT EDIT.
5757
from collections.abc import AsyncIterable, AsyncIterator
5858
from typing import Any
59+
import datetime
5960
6061
from pydantic import TypeAdapter
6162
@@ -855,6 +856,7 @@ def __init__(self, client: river.Client[Any]):
855856
async def {name}(
856857
self,
857858
input: {render_type_expr(input_type)},
859+
timeout: datetime.timedelta,
858860
) -> {render_type_expr(output_type)}:
859861
return await self.client.send_rpc(
860862
{repr(schema_name)},
@@ -863,6 +865,7 @@ async def {name}(
863865
{reindent(" ", render_input_method)},
864866
{reindent(" ", parse_output_method)},
865867
{reindent(" ", parse_error_method)},
868+
timeout,
866869
)
867870
""",
868871
)

replit_river/rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
]
5252
ACK_BIT = 0x0001
5353
STREAM_OPEN_BIT = 0x0002
54-
STREAM_CLOSED_BIT = 0x0004
54+
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2
5555

5656
# these codes are retriable
5757
# if the server sends a response with one of these codes,

tests/test_communication.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from datetime import timedelta
23
from typing import AsyncGenerator
34

45
import pytest
@@ -29,6 +30,7 @@ async def test_rpc_method(client: Client) -> None:
2930
serialize_request,
3031
deserialize_response,
3132
deserialize_error,
33+
timedelta(seconds=20),
3234
)
3335
assert response == "Hello, Alice!"
3436

tests/test_opentelemetry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from typing import AsyncGenerator, AsyncIterator, Iterator
23

34
import grpc
@@ -38,6 +39,7 @@ async def test_rpc_method_span(
3839
serialize_request,
3940
deserialize_response,
4041
deserialize_error,
42+
timedelta(seconds=20),
4143
)
4244
assert response == "Hello, Alice!"
4345
spans = span_exporter.get_finished_spans()

0 commit comments

Comments
 (0)