Skip to content

Commit 2f1bb2d

Browse files
Implement RPC timeouts
1 parent 88f4347 commit 2f1bb2d

File tree

6 files changed

+27
-2
lines changed

6 files changed

+27
-2
lines changed

replit_river/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
33
from contextlib import contextmanager
44
from typing import Any, Generator, Generic, Literal, Optional, Union
5+
from datetime import timedelta
56

67
from opentelemetry import trace
78
from opentelemetry.trace import Span, SpanKind, StatusCode
@@ -60,6 +61,7 @@ async def send_rpc(
6061
request_serializer: Callable[[RequestType], Any],
6162
response_deserializer: Callable[[Any], ResponseType],
6263
error_deserializer: Callable[[Any], ErrorType],
64+
timeout: timedelta,
6365
) -> ResponseType:
6466
with _trace_procedure("rpc", service_name, procedure_name) as span:
6567
session = await self._transport.get_or_create_session()
@@ -71,6 +73,7 @@ async def send_rpc(
7173
response_deserializer,
7274
error_deserializer,
7375
span,
76+
timeout,
7477
)
7578

7679
async def send_upload(

replit_river/client_session.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
from datetime import timedelta
13
import logging
24
from collections.abc import AsyncIterable, AsyncIterator
35
from typing import Any, Callable, Optional, Union
@@ -8,6 +10,7 @@
810
from opentelemetry.trace import Span
911

1012
from replit_river.error_schema import (
13+
ERROR_CODE_CANCEL,
1114
ERROR_CODE_STREAM_CLOSED,
1215
RiverException,
1316
RiverServiceException,
@@ -39,6 +42,7 @@ async def send_rpc(
3942
response_deserializer: Callable[[Any], ResponseType],
4043
error_deserializer: Callable[[Any], ErrorType],
4144
span: Span,
45+
timeout: timedelta,
4246
) -> ResponseType:
4347
"""Sends a single RPC request to the server.
4448
@@ -58,7 +62,19 @@ async def send_rpc(
5862
# Handle potential errors during communication
5963
try:
6064
try:
61-
response = await output.get()
65+
async with asyncio.timeout(int(timeout.total_seconds())):
66+
response = await output.get()
67+
except asyncio.CancelledError as e:
68+
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
69+
await self.send_message(
70+
stream_id=stream_id,
71+
control_flags=STREAM_CLOSED_BIT,
72+
payload={"type": "CLOSE"},
73+
service_name=service_name,
74+
procedure_name=procedure_name,
75+
span=span,
76+
)
77+
raise RiverException(ERROR_CODE_CANCEL, str(e)) from e
6278
except ChannelClosed as e:
6379
raise RiverServiceException(
6480
ERROR_CODE_STREAM_CLOSED,

replit_river/codegen/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ def __init__(self, client: river.Client[Any]):
852852
async def {name}(
853853
self,
854854
input: {render_type_expr(input_type)},
855+
timeout: datetime.timedelta,
855856
) -> {render_type_expr(output_type)}:
856857
return await self.client.send_rpc(
857858
{repr(schema_name)},
@@ -860,6 +861,7 @@ async def {name}(
860861
{reindent(" ", render_input_method)},
861862
{reindent(" ", parse_output_method)},
862863
{reindent(" ", parse_error_method)},
864+
timeout,
863865
)
864866
""",
865867
)

replit_river/rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
]
5252
ACK_BIT = 0x0001
5353
STREAM_OPEN_BIT = 0x0002
54-
STREAM_CLOSED_BIT = 0x0004
54+
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2
5555

5656
# these codes are retriable
5757
# if the server sends a response with one of these codes,

tests/test_communication.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from datetime import timedelta
23
from typing import AsyncGenerator
34

45
import pytest
@@ -18,6 +19,7 @@ async def test_rpc_method(client: Client) -> None:
1819
serialize_request,
1920
deserialize_response,
2021
deserialize_error,
22+
timedelta(seconds=20),
2123
)
2224
assert response == "Hello, Alice!"
2325

tests/test_opentelemetry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from typing import AsyncGenerator
23

34
import pytest
@@ -21,6 +22,7 @@ async def test_rpc_method_span(
2122
serialize_request,
2223
deserialize_response,
2324
deserialize_error,
25+
timedelta(seconds=20),
2426
)
2527
assert response == "Hello, Alice!"
2628
spans = span_exporter.get_finished_spans()

0 commit comments

Comments
 (0)