Skip to content

Commit ecd1b17

Browse files
authored
Drop messages on the client for closed streams/subscriptions (#133)
Why === Each stream/subscription has a messages channel with a capacity of 128 messages. In our main receive loop, we push messages into the channel, blocking until the channel has room. This adds some backpressure, but becomes problematic if the stream is not making any progress. For example, the client could start a stream and then decide to cancel it and not read any of the messages. If the server sends >128 messages, it will fill up the stream's channel leading to a deadlock for the session. This will be more correctly fixed when river v2 support is landed, as that adds support for proper cancellation. In the meantime, we can close the channel when we know we are not going to be reading from it anymore, and then drop any messages destiined for a closed channel. rpc/upload are not affected because the server is only allowed to send 1 payload, and the channel has a buffer size of 1, so there will always be room. > [!Note] > A deadlock can still occur if the client holds a reference to the `AsyncGenerator` but doesn't service it. This is likely a client bug if it happens, but we can probably add some timeouts to putting messages in the stream channel for defense in depth. I'll do this as a followup as I need to think a little bit more about how to properly handle that case. This PR as-is should be a quick win for our usage since we shouldn't be holding references to async generators that we aren't also actively servicing. What changed ============ - Close stream aiochannel in a finalizer in stream/subscription impls - Ignore `ChannelClosed` errors when adding message to stream - Fix tests to use correct method kinds, this was causing subscription/upload RPCs to not work correctly in tests. Luckily things are fine on the server codegen side. Test plan ========= - Added a test which caused a deadlock on the client before this change, but works properly after this change.
1 parent 7d011c7 commit ecd1b17

File tree

5 files changed

+70
-6
lines changed

5 files changed

+70
-6
lines changed

replit_river/client_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ async def send_subscription(
234234
) from e
235235
except Exception as e:
236236
raise e
237+
finally:
238+
output.close()
237239

238240
async def send_stream(
239241
self,
@@ -335,6 +337,8 @@ async def _encode_stream() -> None:
335337
) from e
336338
except Exception as e:
337339
raise e
340+
finally:
341+
output.close()
338342

339343
async def send_close_stream(
340344
self,

replit_river/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,11 @@ async def _add_msg_to_stream(
525525
return
526526
try:
527527
await stream.put(msg.payload)
528-
except (RuntimeError, ChannelClosed) as e:
528+
except ChannelClosed:
529+
# The client is no longer interested in this stream,
530+
# just drop the message.
531+
pass
532+
except RuntimeError as e:
529533
raise InvalidMessageException(e) from e
530534

531535
async def _remove_acked_messages_in_buffer(self) -> None:

tests/common_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def upload_handler(
3939

4040
basic_upload: HandlerMapping = {
4141
("test_service", "upload_method"): (
42-
"upload",
42+
"upload-stream",
4343
upload_method_handler(upload_handler, deserialize_request, serialize_response),
4444
),
4545
}
@@ -54,7 +54,7 @@ async def subscription_handler(
5454

5555
basic_subscription: HandlerMapping = {
5656
("test_service", "subscription_method"): (
57-
"subscription",
57+
"subscription-stream",
5858
subscription_method_handler(
5959
subscription_handler, deserialize_request, serialize_response
6060
),

tests/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Mapping
1+
from typing import Any, Literal, Mapping
22

33
import nanoid
44
import pytest
@@ -16,7 +16,8 @@
1616
# Modular fixtures
1717
pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"]
1818

19-
HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]
19+
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
20+
HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]]
2021

2122

2223
def transport_message(

tests/test_communication.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from typing import AsyncGenerator
44

55
import pytest
6+
from grpc.aio import grpc
67

78
from replit_river.client import Client
89
from replit_river.error_schema import RiverError
10+
from replit_river.rpc import subscription_method_handler
911
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
1012
from tests.common_handlers import (
1113
basic_rpc_method,
@@ -14,9 +16,12 @@
1416
basic_upload,
1517
)
1618
from tests.conftest import (
19+
HandlerMapping,
1720
deserialize_error,
21+
deserialize_request,
1822
deserialize_response,
1923
serialize_request,
24+
serialize_response,
2025
)
2126

2227

@@ -101,6 +106,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
101106
@pytest.mark.asyncio
102107
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
103108
async def test_subscription_method(client: Client) -> None:
109+
messages = []
104110
async for response in client.send_subscription(
105111
"test_service",
106112
"subscription_method",
@@ -110,7 +116,8 @@ async def test_subscription_method(client: Client) -> None:
110116
deserialize_error,
111117
):
112118
assert isinstance(response, str)
113-
assert "Subscription message" in response
119+
messages.append(response)
120+
assert messages == [f"Subscription message {i} for Bob" for i in range(5)]
114121

115122

116123
@pytest.mark.asyncio
@@ -213,3 +220,51 @@ async def stream_data() -> AsyncGenerator[str, None]:
213220
"Stream response for Stream Data 1",
214221
"Stream response for Stream Data 2",
215222
]
223+
224+
225+
async def flood_subscription_handler(
226+
request: str, context: grpc.aio.ServicerContext
227+
) -> AsyncGenerator[str, None]:
228+
for i in range(1024):
229+
yield f"Subscription message {i} for {request}"
230+
231+
232+
flood_subscription: HandlerMapping = {
233+
("test_service", "flood_subscription_method"): (
234+
"subscription-stream",
235+
subscription_method_handler(
236+
flood_subscription_handler, deserialize_request, serialize_response
237+
),
238+
),
239+
}
240+
241+
242+
@pytest.mark.asyncio
243+
@pytest.mark.parametrize("handlers", [{**basic_rpc_method, **flood_subscription}])
244+
async def test_ignore_flood_subscription(client: Client) -> None:
245+
sub = client.send_subscription(
246+
"test_service",
247+
"flood_subscription_method",
248+
"Initial Subscription Data",
249+
serialize_request,
250+
deserialize_response,
251+
deserialize_error,
252+
)
253+
254+
# read one entry to start the subscription
255+
await sub.__anext__()
256+
# close the subscription so we can signal that we're not
257+
# interested in the rest of the subscription.
258+
await sub.aclose()
259+
260+
# ensure that subsequent RPCs still work
261+
response = await client.send_rpc(
262+
"test_service",
263+
"rpc_method",
264+
"Alice",
265+
serialize_request,
266+
deserialize_response,
267+
deserialize_error,
268+
timedelta(seconds=20),
269+
)
270+
assert response == "Hello, Alice!"

0 commit comments

Comments
 (0)