Skip to content

Commit 8774a61

Browse files
Fix server streaming handler not cancelled on client disconnect
When a client disconnects during a server streaming RPC, the async generator continued yielding indefinitely because receive() was never consulted after the initial request was consumed. Fix this by spawning a background task in the `EndpointServerStream` case that monitors receive() for http.disconnect. When detected, an Event is set; the streaming loop checks it between yields and raises ConnectError(CANCELED) if set. The response stream is also explicitly `aclose()`'d in the finally block so generator finally-clauses run promptly rather than being deferred to GC. Fixes #174. Signed-off-by: Stefan VanBuren <svanburen@buf.build>
1 parent 9cb72bd commit 8774a61

2 files changed

Lines changed: 105 additions & 13 deletions

File tree

src/connectrpc/_server_async.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import functools
55
import inspect
66
from abc import ABC, abstractmethod
7-
from asyncio import CancelledError, sleep
7+
from asyncio import CancelledError, Event, create_task, sleep
88
from dataclasses import replace
99
from http import HTTPStatus
1010
from typing import TYPE_CHECKING, Generic, TypeVar, cast
@@ -385,6 +385,9 @@ async def _handle_stream(
385385
self._read_max_bytes,
386386
)
387387

388+
disconnect_detected: Event | None = None
389+
monitor_task = None
390+
388391
match endpoint:
389392
case EndpointUnary():
390393
request = await _consume_single_request(request_stream)
@@ -396,22 +399,48 @@ async def _handle_stream(
396399
case EndpointServerStream():
397400
request = await _consume_single_request(request_stream)
398401
response_stream = endpoint.function(request, ctx)
402+
403+
# The request has been fully consumed; monitor receive() for a
404+
# client disconnect so we can stop streaming promptly.
405+
disconnect_detected = Event()
406+
407+
async def _watch_for_disconnect() -> None:
408+
while True:
409+
msg = await receive()
410+
if msg["type"] == "http.disconnect":
411+
disconnect_detected.set()
412+
return
413+
414+
monitor_task = create_task(_watch_for_disconnect())
399415
case EndpointBidiStream():
400416
response_stream = endpoint.function(request_stream, ctx)
401417

402-
async for message in response_stream:
403-
# Don't send headers until the first message to allow logic a chance to add
404-
# response headers.
405-
if not sent_headers:
406-
await _send_stream_response_headers(
407-
send, protocol, codec, resp_compression.name(), ctx
418+
try:
419+
async for message in response_stream:
420+
if disconnect_detected is not None and disconnect_detected.is_set():
421+
raise ConnectError(Code.CANCELED, "Client disconnected")
422+
# Don't send headers until the first message to allow logic a chance to add
423+
# response headers.
424+
if not sent_headers:
425+
await _send_stream_response_headers(
426+
send, protocol, codec, resp_compression.name(), ctx
427+
)
428+
sent_headers = True
429+
430+
body = writer.write(message)
431+
await send(
432+
{"type": "http.response.body", "body": body, "more_body": True}
408433
)
409-
sent_headers = True
410-
411-
body = writer.write(message)
412-
await send(
413-
{"type": "http.response.body", "body": body, "more_body": True}
414-
)
434+
finally:
435+
# Explicitly close the stream so that any generator finally-blocks
436+
# run promptly (Python defers async-generator cleanup to GC otherwise).
437+
await response_stream.aclose()
438+
if monitor_task is not None:
439+
monitor_task.cancel()
440+
try:
441+
await monitor_task
442+
except (CancelledError, Exception):
443+
pass
415444
except CancelledError as e:
416445
raise ConnectError(Code.CANCELED, "Request was cancelled") from e
417446
except Exception as e:

test/test_roundtrip.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import struct
35
from typing import TYPE_CHECKING
46

57
import pytest
@@ -280,3 +282,64 @@ async def request_stream():
280282
else:
281283
assert len(requests) == 2
282284
assert len(responses) == 1
285+
286+
287+
@pytest.mark.asyncio
288+
async def test_server_stream_client_disconnect() -> None:
289+
"""Server streaming generator should be closed when the client disconnects.
290+
291+
Regression test for https://github.com/connectrpc/connect-python/issues/174.
292+
"""
293+
generator_closed = asyncio.Event()
294+
295+
class InfiniteHaberdasher(Haberdasher):
296+
async def make_similar_hats(self, request, ctx):
297+
try:
298+
while True:
299+
yield Hat(size=request.inches, color="green")
300+
await asyncio.sleep(0) # yield control to event loop
301+
finally:
302+
generator_closed.set()
303+
304+
app = HaberdasherASGIApplication(InfiniteHaberdasher())
305+
306+
# Encode a Connect protocol (application/connect+proto) request for Size(inches=10).
307+
request_bytes = Size(inches=10).SerializeToString()
308+
request_body = struct.pack(">BI", 0, len(request_bytes)) + request_bytes
309+
310+
disconnect_trigger = asyncio.Event()
311+
response_count = 0
312+
call_count = 0
313+
314+
async def receive():
315+
nonlocal call_count
316+
call_count += 1
317+
if call_count == 1:
318+
return {"type": "http.request", "body": request_body, "more_body": False}
319+
# Block until the test is ready to simulate a disconnect.
320+
await disconnect_trigger.wait()
321+
return {"type": "http.disconnect"}
322+
323+
async def send(message):
324+
nonlocal response_count
325+
if message.get("type") == "http.response.body" and message.get("more_body", False):
326+
response_count += 1
327+
if response_count >= 3:
328+
disconnect_trigger.set()
329+
330+
scope = {
331+
"type": "http",
332+
"method": "POST",
333+
"path": "/connectrpc.example.Haberdasher/MakeSimilarHats",
334+
"query_string": b"",
335+
"root_path": "",
336+
"headers": [
337+
(b"content-type", b"application/connect+proto"),
338+
],
339+
}
340+
341+
# Without the fix the app hangs forever (generator never stopped), causing a
342+
# TimeoutError here. With the fix it terminates promptly after the disconnect.
343+
await asyncio.wait_for(app(scope, receive, send), timeout=5.0)
344+
345+
assert generator_closed.is_set(), "generator should be closed after client disconnect"

0 commit comments

Comments
 (0)