Skip to content

Commit a3c84c4

Browse files
committed
refactor(server): simplify streamable HTTP request handling
1 parent bcda500 commit a3c84c4

File tree

1 file changed

+125
-101
lines changed

1 file changed

+125
-101
lines changed

src/mcp/server/streamable_http.py

Lines changed: 125 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import anyio
1919
import pydantic_core
20+
from anyio.abc import ObjectReceiveStream, ObjectSendStream
2021
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2122
from pydantic import ValidationError
2223
from sse_starlette import EventSourceResponse
@@ -427,6 +428,110 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
427428
return False
428429
return True
429430

431+
async def _handle_post_request_json_mode(
432+
self,
433+
*,
434+
scope: Scope,
435+
request: Request,
436+
receive: Receive,
437+
send: Send,
438+
writer: ObjectSendStream[SessionMessage],
439+
message: JSONRPCRequest,
440+
request_id: str,
441+
request_stream_reader: ObjectReceiveStream[EventMessage],
442+
) -> None:
443+
metadata = ServerMessageMetadata(request_context=request)
444+
session_message = SessionMessage(message, metadata=metadata)
445+
await writer.send(session_message)
446+
try:
447+
# Process messages from the request-specific stream.
448+
response_message: JSONRPCResponse | JSONRPCError | None = None
449+
450+
async for event_message in request_stream_reader: # pragma: no branch
451+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
452+
response_message = event_message.message
453+
break
454+
else: # pragma: no cover
455+
logger.debug("received: %s", event_message.message.method)
456+
457+
if response_message:
458+
response = self._create_json_response(response_message)
459+
await response(scope, receive, send)
460+
else: # pragma: no cover
461+
logger.error("No response message received before stream closed")
462+
response = self._create_error_response(
463+
"Error processing request: No response received",
464+
HTTPStatus.INTERNAL_SERVER_ERROR,
465+
)
466+
await response(scope, receive, send)
467+
except Exception: # pragma: no cover
468+
logger.exception("Error processing JSON response")
469+
response = self._create_error_response(
470+
"Error processing request",
471+
HTTPStatus.INTERNAL_SERVER_ERROR,
472+
INTERNAL_ERROR,
473+
)
474+
await response(scope, receive, send)
475+
finally:
476+
await self._clean_up_memory_streams(request_id)
477+
478+
async def _handle_post_request_sse_mode(
479+
self,
480+
*,
481+
scope: Scope,
482+
request: Request,
483+
receive: Receive,
484+
send: Send,
485+
writer: ObjectSendStream[SessionMessage],
486+
message: JSONRPCRequest,
487+
request_id: str,
488+
request_stream_reader: ObjectReceiveStream[EventMessage],
489+
protocol_version: str,
490+
) -> None: # pragma: no cover
491+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
492+
self._sse_stream_writers[request_id] = sse_stream_writer
493+
494+
async def sse_writer() -> None:
495+
try:
496+
async with sse_stream_writer, request_stream_reader:
497+
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
498+
async for event_message in request_stream_reader:
499+
event_data = self._create_event_data(event_message)
500+
await sse_stream_writer.send(event_data)
501+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
502+
break
503+
except anyio.ClosedResourceError:
504+
logger.debug("SSE stream closed by close_sse_stream()")
505+
except Exception:
506+
logger.exception("Error in SSE writer")
507+
finally:
508+
logger.debug("Closing SSE writer")
509+
self._sse_stream_writers.pop(request_id, None)
510+
await self._clean_up_memory_streams(request_id)
511+
512+
headers = {
513+
"Cache-Control": "no-cache, no-transform",
514+
"Connection": "keep-alive",
515+
"Content-Type": CONTENT_TYPE_SSE,
516+
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
517+
}
518+
response = EventSourceResponse(
519+
content=sse_stream_reader,
520+
data_sender_callable=sse_writer,
521+
headers=headers,
522+
)
523+
524+
try:
525+
async with anyio.create_task_group() as tg:
526+
tg.start_soon(response, scope, receive, send)
527+
session_message = self._create_session_message(message, request, request_id, protocol_version)
528+
await writer.send(session_message)
529+
except Exception:
530+
logger.exception("SSE response error")
531+
await sse_stream_writer.aclose()
532+
await sse_stream_reader.aclose()
533+
await self._clean_up_memory_streams(request_id)
534+
430535
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
431536
"""Handle POST requests containing JSON-RPC messages."""
432537
writer = self._read_stream_writer
@@ -527,110 +632,29 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
527632
request_stream_reader = self._request_streams[request_id][1]
528633

529634
if self.is_json_response_enabled:
530-
# Process the message
531-
metadata = ServerMessageMetadata(request_context=request)
532-
session_message = SessionMessage(message, metadata=metadata)
533-
await writer.send(session_message)
534-
try:
535-
# Process messages from the request-specific stream
536-
# We need to collect all messages until we get a response
537-
response_message = None
538-
539-
# Use similar approach to SSE writer for consistency
540-
async for event_message in request_stream_reader: # pragma: no branch
541-
# If it's a response, this is what we're waiting for
542-
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
543-
response_message = event_message.message
544-
break
545-
# For notifications and request, keep waiting
546-
else: # pragma: no cover
547-
logger.debug(f"received: {event_message.message.method}")
548-
549-
# At this point we should have a response
550-
if response_message:
551-
# Create JSON response
552-
response = self._create_json_response(response_message)
553-
await response(scope, receive, send)
554-
else: # pragma: no cover
555-
# This shouldn't happen in normal operation
556-
logger.error("No response message received before stream closed")
557-
response = self._create_error_response(
558-
"Error processing request: No response received",
559-
HTTPStatus.INTERNAL_SERVER_ERROR,
560-
)
561-
await response(scope, receive, send)
562-
except Exception: # pragma: no cover
563-
logger.exception("Error processing JSON response")
564-
response = self._create_error_response(
565-
"Error processing request",
566-
HTTPStatus.INTERNAL_SERVER_ERROR,
567-
INTERNAL_ERROR,
568-
)
569-
await response(scope, receive, send)
570-
finally:
571-
await self._clean_up_memory_streams(request_id)
635+
await self._handle_post_request_json_mode(
636+
scope=scope,
637+
request=request,
638+
receive=receive,
639+
send=send,
640+
writer=writer,
641+
message=message,
642+
request_id=request_id,
643+
request_stream_reader=request_stream_reader,
644+
)
572645
else: # pragma: no cover
573-
# Create SSE stream
574-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
575-
576-
# Store writer reference so close_sse_stream() can close it
577-
self._sse_stream_writers[request_id] = sse_stream_writer
578-
579-
async def sse_writer():
580-
# Get the request ID from the incoming request message
581-
try:
582-
async with sse_stream_writer, request_stream_reader:
583-
# Send priming event for SSE resumability
584-
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
585-
586-
# Process messages from the request-specific stream
587-
async for event_message in request_stream_reader:
588-
# Build the event data
589-
event_data = self._create_event_data(event_message)
590-
await sse_stream_writer.send(event_data)
591-
592-
# If response, remove from pending streams and close
593-
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
594-
break
595-
except anyio.ClosedResourceError:
596-
# Expected when close_sse_stream() is called
597-
logger.debug("SSE stream closed by close_sse_stream()")
598-
except Exception:
599-
logger.exception("Error in SSE writer")
600-
finally:
601-
logger.debug("Closing SSE writer")
602-
self._sse_stream_writers.pop(request_id, None)
603-
await self._clean_up_memory_streams(request_id)
604-
605-
# Create and start EventSourceResponse
606-
# SSE stream mode (original behavior)
607-
# Set up headers
608-
headers = {
609-
"Cache-Control": "no-cache, no-transform",
610-
"Connection": "keep-alive",
611-
"Content-Type": CONTENT_TYPE_SSE,
612-
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
613-
}
614-
response = EventSourceResponse(
615-
content=sse_stream_reader,
616-
data_sender_callable=sse_writer,
617-
headers=headers,
646+
await self._handle_post_request_sse_mode(
647+
scope=scope,
648+
request=request,
649+
receive=receive,
650+
send=send,
651+
writer=writer,
652+
message=message,
653+
request_id=request_id,
654+
request_stream_reader=request_stream_reader,
655+
protocol_version=protocol_version,
618656
)
619657

620-
# Start the SSE response (this will send headers immediately)
621-
try:
622-
# First send the response to establish the SSE connection
623-
async with anyio.create_task_group() as tg:
624-
tg.start_soon(response, scope, receive, send)
625-
# Then send the message to be processed by the server
626-
session_message = self._create_session_message(message, request, request_id, protocol_version)
627-
await writer.send(session_message)
628-
except Exception:
629-
logger.exception("SSE response error")
630-
await sse_stream_writer.aclose()
631-
await sse_stream_reader.aclose()
632-
await self._clean_up_memory_streams(request_id)
633-
634658
except anyio.ClosedResourceError as err: # pragma: no cover
635659
# Session terminated (e.g., DELETE processed) while handling POST.
636660
# Response may have already been sent (e.g., 202 for notifications).

0 commit comments

Comments
 (0)