|
17 | 17 |
|
18 | 18 | import anyio |
19 | 19 | import pydantic_core |
| 20 | +from anyio.abc import ObjectReceiveStream, ObjectSendStream |
20 | 21 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
21 | 22 | from pydantic import ValidationError |
22 | 23 | from sse_starlette import EventSourceResponse |
@@ -427,6 +428,110 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se |
427 | 428 | return False |
428 | 429 | return True |
429 | 430 |
|
| 431 | + async def _handle_post_request_json_mode( |
| 432 | + self, |
| 433 | + *, |
| 434 | + scope: Scope, |
| 435 | + request: Request, |
| 436 | + receive: Receive, |
| 437 | + send: Send, |
| 438 | + writer: ObjectSendStream[SessionMessage], |
| 439 | + message: JSONRPCRequest, |
| 440 | + request_id: str, |
| 441 | + request_stream_reader: ObjectReceiveStream[EventMessage], |
| 442 | + ) -> None: |
| 443 | + metadata = ServerMessageMetadata(request_context=request) |
| 444 | + session_message = SessionMessage(message, metadata=metadata) |
| 445 | + await writer.send(session_message) |
| 446 | + try: |
| 447 | + # Process messages from the request-specific stream. |
| 448 | + response_message: JSONRPCResponse | JSONRPCError | None = None |
| 449 | + |
| 450 | + async for event_message in request_stream_reader: # pragma: no branch |
| 451 | + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): |
| 452 | + response_message = event_message.message |
| 453 | + break |
| 454 | + else: # pragma: no cover |
| 455 | + logger.debug("received: %s", event_message.message.method) |
| 456 | + |
| 457 | + if response_message: |
| 458 | + response = self._create_json_response(response_message) |
| 459 | + await response(scope, receive, send) |
| 460 | + else: # pragma: no cover |
| 461 | + logger.error("No response message received before stream closed") |
| 462 | + response = self._create_error_response( |
| 463 | + "Error processing request: No response received", |
| 464 | + HTTPStatus.INTERNAL_SERVER_ERROR, |
| 465 | + ) |
| 466 | + await response(scope, receive, send) |
| 467 | + except Exception: # pragma: no cover |
| 468 | + logger.exception("Error processing JSON response") |
| 469 | + response = self._create_error_response( |
| 470 | + "Error processing request", |
| 471 | + HTTPStatus.INTERNAL_SERVER_ERROR, |
| 472 | + INTERNAL_ERROR, |
| 473 | + ) |
| 474 | + await response(scope, receive, send) |
| 475 | + finally: |
| 476 | + await self._clean_up_memory_streams(request_id) |
| 477 | + |
| 478 | + async def _handle_post_request_sse_mode( |
| 479 | + self, |
| 480 | + *, |
| 481 | + scope: Scope, |
| 482 | + request: Request, |
| 483 | + receive: Receive, |
| 484 | + send: Send, |
| 485 | + writer: ObjectSendStream[SessionMessage], |
| 486 | + message: JSONRPCRequest, |
| 487 | + request_id: str, |
| 488 | + request_stream_reader: ObjectReceiveStream[EventMessage], |
| 489 | + protocol_version: str, |
| 490 | + ) -> None: # pragma: no cover |
| 491 | + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) |
| 492 | + self._sse_stream_writers[request_id] = sse_stream_writer |
| 493 | + |
| 494 | + async def sse_writer() -> None: |
| 495 | + try: |
| 496 | + async with sse_stream_writer, request_stream_reader: |
| 497 | + await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version) |
| 498 | + async for event_message in request_stream_reader: |
| 499 | + event_data = self._create_event_data(event_message) |
| 500 | + await sse_stream_writer.send(event_data) |
| 501 | + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): |
| 502 | + break |
| 503 | + except anyio.ClosedResourceError: |
| 504 | + logger.debug("SSE stream closed by close_sse_stream()") |
| 505 | + except Exception: |
| 506 | + logger.exception("Error in SSE writer") |
| 507 | + finally: |
| 508 | + logger.debug("Closing SSE writer") |
| 509 | + self._sse_stream_writers.pop(request_id, None) |
| 510 | + await self._clean_up_memory_streams(request_id) |
| 511 | + |
| 512 | + headers = { |
| 513 | + "Cache-Control": "no-cache, no-transform", |
| 514 | + "Connection": "keep-alive", |
| 515 | + "Content-Type": CONTENT_TYPE_SSE, |
| 516 | + **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), |
| 517 | + } |
| 518 | + response = EventSourceResponse( |
| 519 | + content=sse_stream_reader, |
| 520 | + data_sender_callable=sse_writer, |
| 521 | + headers=headers, |
| 522 | + ) |
| 523 | + |
| 524 | + try: |
| 525 | + async with anyio.create_task_group() as tg: |
| 526 | + tg.start_soon(response, scope, receive, send) |
| 527 | + session_message = self._create_session_message(message, request, request_id, protocol_version) |
| 528 | + await writer.send(session_message) |
| 529 | + except Exception: |
| 530 | + logger.exception("SSE response error") |
| 531 | + await sse_stream_writer.aclose() |
| 532 | + await sse_stream_reader.aclose() |
| 533 | + await self._clean_up_memory_streams(request_id) |
| 534 | + |
430 | 535 | async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: |
431 | 536 | """Handle POST requests containing JSON-RPC messages.""" |
432 | 537 | writer = self._read_stream_writer |
@@ -527,110 +632,29 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re |
527 | 632 | request_stream_reader = self._request_streams[request_id][1] |
528 | 633 |
|
529 | 634 | if self.is_json_response_enabled: |
530 | | - # Process the message |
531 | | - metadata = ServerMessageMetadata(request_context=request) |
532 | | - session_message = SessionMessage(message, metadata=metadata) |
533 | | - await writer.send(session_message) |
534 | | - try: |
535 | | - # Process messages from the request-specific stream |
536 | | - # We need to collect all messages until we get a response |
537 | | - response_message = None |
538 | | - |
539 | | - # Use similar approach to SSE writer for consistency |
540 | | - async for event_message in request_stream_reader: # pragma: no branch |
541 | | - # If it's a response, this is what we're waiting for |
542 | | - if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): |
543 | | - response_message = event_message.message |
544 | | - break |
545 | | - # For notifications and request, keep waiting |
546 | | - else: # pragma: no cover |
547 | | - logger.debug(f"received: {event_message.message.method}") |
548 | | - |
549 | | - # At this point we should have a response |
550 | | - if response_message: |
551 | | - # Create JSON response |
552 | | - response = self._create_json_response(response_message) |
553 | | - await response(scope, receive, send) |
554 | | - else: # pragma: no cover |
555 | | - # This shouldn't happen in normal operation |
556 | | - logger.error("No response message received before stream closed") |
557 | | - response = self._create_error_response( |
558 | | - "Error processing request: No response received", |
559 | | - HTTPStatus.INTERNAL_SERVER_ERROR, |
560 | | - ) |
561 | | - await response(scope, receive, send) |
562 | | - except Exception: # pragma: no cover |
563 | | - logger.exception("Error processing JSON response") |
564 | | - response = self._create_error_response( |
565 | | - "Error processing request", |
566 | | - HTTPStatus.INTERNAL_SERVER_ERROR, |
567 | | - INTERNAL_ERROR, |
568 | | - ) |
569 | | - await response(scope, receive, send) |
570 | | - finally: |
571 | | - await self._clean_up_memory_streams(request_id) |
| 635 | + await self._handle_post_request_json_mode( |
| 636 | + scope=scope, |
| 637 | + request=request, |
| 638 | + receive=receive, |
| 639 | + send=send, |
| 640 | + writer=writer, |
| 641 | + message=message, |
| 642 | + request_id=request_id, |
| 643 | + request_stream_reader=request_stream_reader, |
| 644 | + ) |
572 | 645 | else: # pragma: no cover |
573 | | - # Create SSE stream |
574 | | - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) |
575 | | - |
576 | | - # Store writer reference so close_sse_stream() can close it |
577 | | - self._sse_stream_writers[request_id] = sse_stream_writer |
578 | | - |
579 | | - async def sse_writer(): |
580 | | - # Get the request ID from the incoming request message |
581 | | - try: |
582 | | - async with sse_stream_writer, request_stream_reader: |
583 | | - # Send priming event for SSE resumability |
584 | | - await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version) |
585 | | - |
586 | | - # Process messages from the request-specific stream |
587 | | - async for event_message in request_stream_reader: |
588 | | - # Build the event data |
589 | | - event_data = self._create_event_data(event_message) |
590 | | - await sse_stream_writer.send(event_data) |
591 | | - |
592 | | - # If response, remove from pending streams and close |
593 | | - if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): |
594 | | - break |
595 | | - except anyio.ClosedResourceError: |
596 | | - # Expected when close_sse_stream() is called |
597 | | - logger.debug("SSE stream closed by close_sse_stream()") |
598 | | - except Exception: |
599 | | - logger.exception("Error in SSE writer") |
600 | | - finally: |
601 | | - logger.debug("Closing SSE writer") |
602 | | - self._sse_stream_writers.pop(request_id, None) |
603 | | - await self._clean_up_memory_streams(request_id) |
604 | | - |
605 | | - # Create and start EventSourceResponse |
606 | | - # SSE stream mode (original behavior) |
607 | | - # Set up headers |
608 | | - headers = { |
609 | | - "Cache-Control": "no-cache, no-transform", |
610 | | - "Connection": "keep-alive", |
611 | | - "Content-Type": CONTENT_TYPE_SSE, |
612 | | - **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), |
613 | | - } |
614 | | - response = EventSourceResponse( |
615 | | - content=sse_stream_reader, |
616 | | - data_sender_callable=sse_writer, |
617 | | - headers=headers, |
| 646 | + await self._handle_post_request_sse_mode( |
| 647 | + scope=scope, |
| 648 | + request=request, |
| 649 | + receive=receive, |
| 650 | + send=send, |
| 651 | + writer=writer, |
| 652 | + message=message, |
| 653 | + request_id=request_id, |
| 654 | + request_stream_reader=request_stream_reader, |
| 655 | + protocol_version=protocol_version, |
618 | 656 | ) |
619 | 657 |
|
620 | | - # Start the SSE response (this will send headers immediately) |
621 | | - try: |
622 | | - # First send the response to establish the SSE connection |
623 | | - async with anyio.create_task_group() as tg: |
624 | | - tg.start_soon(response, scope, receive, send) |
625 | | - # Then send the message to be processed by the server |
626 | | - session_message = self._create_session_message(message, request, request_id, protocol_version) |
627 | | - await writer.send(session_message) |
628 | | - except Exception: |
629 | | - logger.exception("SSE response error") |
630 | | - await sse_stream_writer.aclose() |
631 | | - await sse_stream_reader.aclose() |
632 | | - await self._clean_up_memory_streams(request_id) |
633 | | - |
634 | 658 | except anyio.ClosedResourceError as err: # pragma: no cover |
635 | 659 | # Session terminated (e.g., DELETE processed) while handling POST. |
636 | 660 | # Response may have already been sent (e.g., 202 for notifications). |
|
0 commit comments