Skip to content

Commit 4df01d1

Browse files
Fix server streaming handler not cancelled on client disconnect
When a client disconnects during a server streaming RPC, the async generator continued yielding indefinitely because receive() was never consulted after the initial request was consumed. Fix this by spawning a background task in the `EndpointServerStream` case that monitors receive() for http.disconnect. When detected, an Event is set; the streaming loop checks it between yields and raises ConnectError(CANCELED) if set. The response stream is also explicitly `aclose()`'d in the finally block so generator finally-clauses run promptly rather than being deferred to GC. Fixes #174. Signed-off-by: Stefan VanBuren <svanburen@buf.build>
1 parent 9cb72bd commit 4df01d1

File tree

2 files changed

+106
-13
lines changed

2 files changed

+106
-13
lines changed

src/connectrpc/_server_async.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import base64
4+
import contextlib
45
import functools
56
import inspect
67
from abc import ABC, abstractmethod
7-
from asyncio import CancelledError, sleep
8+
from asyncio import CancelledError, Event, create_task, sleep
89
from dataclasses import replace
910
from http import HTTPStatus
1011
from typing import TYPE_CHECKING, Generic, TypeVar, cast
@@ -385,6 +386,9 @@ async def _handle_stream(
385386
self._read_max_bytes,
386387
)
387388

389+
disconnect_detected: Event | None = None
390+
monitor_task = None
391+
388392
match endpoint:
389393
case EndpointUnary():
390394
request = await _consume_single_request(request_stream)
@@ -396,22 +400,46 @@ async def _handle_stream(
396400
case EndpointServerStream():
397401
request = await _consume_single_request(request_stream)
398402
response_stream = endpoint.function(request, ctx)
403+
404+
# The request has been fully consumed; monitor receive() for a
405+
# client disconnect so we can stop streaming promptly.
406+
disconnect_detected = Event()
407+
408+
async def _watch_for_disconnect() -> None:
409+
while True:
410+
msg = await receive()
411+
if msg["type"] == "http.disconnect":
412+
disconnect_detected.set()
413+
return
414+
415+
monitor_task = create_task(_watch_for_disconnect())
399416
case EndpointBidiStream():
400417
response_stream = endpoint.function(request_stream, ctx)
401418

402-
async for message in response_stream:
403-
# Don't send headers until the first message to allow logic a chance to add
404-
# response headers.
405-
if not sent_headers:
406-
await _send_stream_response_headers(
407-
send, protocol, codec, resp_compression.name(), ctx
419+
try:
420+
async for message in response_stream:
421+
if disconnect_detected is not None and disconnect_detected.is_set():
422+
raise ConnectError(Code.CANCELED, "Client disconnected")
423+
# Don't send headers until the first message to allow logic a chance to add
424+
# response headers.
425+
if not sent_headers:
426+
await _send_stream_response_headers(
427+
send, protocol, codec, resp_compression.name(), ctx
428+
)
429+
sent_headers = True
430+
431+
body = writer.write(message)
432+
await send(
433+
{"type": "http.response.body", "body": body, "more_body": True}
408434
)
409-
sent_headers = True
410-
411-
body = writer.write(message)
412-
await send(
413-
{"type": "http.response.body", "body": body, "more_body": True}
414-
)
435+
finally:
436+
# Explicitly close the stream so that any generator finally-blocks
437+
# run promptly (Python defers async-generator cleanup to GC otherwise).
438+
await response_stream.aclose()
439+
if monitor_task is not None:
440+
monitor_task.cancel()
441+
with contextlib.suppress(CancelledError, Exception):
442+
await monitor_task
415443
except CancelledError as e:
416444
raise ConnectError(Code.CANCELED, "Request was cancelled") from e
417445
except Exception as e:

test/test_roundtrip.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import struct
35
from typing import TYPE_CHECKING
46

57
import pytest
@@ -280,3 +282,66 @@ async def request_stream():
280282
else:
281283
assert len(requests) == 2
282284
assert len(responses) == 1
285+
286+
287+
@pytest.mark.asyncio
288+
async def test_server_stream_client_disconnect() -> None:
289+
"""Server streaming generator should be closed when the client disconnects.
290+
291+
Regression test for https://github.com/connectrpc/connect-python/issues/174.
292+
"""
293+
generator_closed = asyncio.Event()
294+
295+
class InfiniteHaberdasher(Haberdasher):
296+
async def make_similar_hats(self, request, ctx):
297+
try:
298+
while True:
299+
yield Hat(size=request.inches, color="green")
300+
await asyncio.sleep(0) # yield control to event loop
301+
finally:
302+
generator_closed.set()
303+
304+
app = HaberdasherASGIApplication(InfiniteHaberdasher())
305+
306+
# Encode a Connect protocol (application/connect+proto) request for Size(inches=10).
307+
request_bytes = Size(inches=10).SerializeToString()
308+
request_body = struct.pack(">BI", 0, len(request_bytes)) + request_bytes
309+
310+
disconnect_trigger = asyncio.Event()
311+
response_count = 0
312+
call_count = 0
313+
314+
async def receive():
315+
nonlocal call_count
316+
call_count += 1
317+
if call_count == 1:
318+
return {"type": "http.request", "body": request_body, "more_body": False}
319+
# Block until the test is ready to simulate a disconnect.
320+
await disconnect_trigger.wait()
321+
return {"type": "http.disconnect"}
322+
323+
async def send(message):
324+
nonlocal response_count
325+
if message.get("type") == "http.response.body" and message.get(
326+
"more_body", False
327+
):
328+
response_count += 1
329+
if response_count >= 3:
330+
disconnect_trigger.set()
331+
332+
scope = {
333+
"type": "http",
334+
"method": "POST",
335+
"path": "/connectrpc.example.Haberdasher/MakeSimilarHats",
336+
"query_string": b"",
337+
"root_path": "",
338+
"headers": [(b"content-type", b"application/connect+proto")],
339+
}
340+
341+
# Without the fix the app hangs forever (generator never stopped), causing a
342+
# TimeoutError here. With the fix it terminates promptly after the disconnect.
343+
await asyncio.wait_for(app(scope, receive, send), timeout=5.0)
344+
345+
assert generator_closed.is_set(), (
346+
"generator should be closed after client disconnect"
347+
)

0 commit comments

Comments (0)