|
7 | 7 | from collections.abc import AsyncGenerator, Awaitable, Callable |
8 | 8 | from contextlib import asynccontextmanager |
9 | 9 | from dataclasses import dataclass |
10 | | -from types import TracebackType |
11 | 10 |
|
12 | 11 | import anyio |
13 | 12 | import httpx |
|
18 | 17 | from mcp.client._transport import TransportStreams |
19 | 18 | from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams |
20 | 19 | from mcp.shared._httpx_utils import create_mcp_http_client |
21 | | -from mcp.shared._stream_protocols import WriteStream |
22 | 20 | from mcp.shared.message import ClientMessageMetadata, SessionMessage |
23 | 21 | from mcp.types import ( |
24 | 22 | INTERNAL_ERROR, |
@@ -514,35 +512,6 @@ def get_session_id(self) -> str | None: |
514 | 512 | return self.session_id # pragma: no cover |
515 | 513 |
|
516 | 514 |
|
517 | | -class _SessionAwareWriteStream: |
518 | | - """Write-stream wrapper that exposes the transport session ID.""" |
519 | | - |
520 | | - def __init__(self, inner: WriteStream[SessionMessage], transport: StreamableHTTPTransport) -> None: |
521 | | - self._inner = inner |
522 | | - self._transport = transport |
523 | | - |
524 | | - async def send(self, item: SessionMessage) -> None: |
525 | | - await self._inner.send(item) |
526 | | - |
527 | | - async def aclose(self) -> None: |
528 | | - await self._inner.aclose() |
529 | | - |
530 | | - def get_session_id(self) -> str | None: |
531 | | - return self._transport.session_id |
532 | | - |
533 | | - async def __aenter__(self) -> _SessionAwareWriteStream: |
534 | | - await self._inner.__aenter__() |
535 | | - return self |
536 | | - |
537 | | - async def __aexit__( |
538 | | - self, |
539 | | - exc_type: type[BaseException] | None, |
540 | | - exc_val: BaseException | None, |
541 | | - exc_tb: TracebackType | None, |
542 | | - ) -> bool | None: |
543 | | - return await self._inner.__aexit__(exc_type, exc_val, exc_tb) |
544 | | - |
545 | | - |
546 | 515 | # TODO(Marcelo): I've dropped the `get_session_id` callback because it breaks the Transport protocol. Is that needed? |
547 | 516 | # It's a completely wrong abstraction, so removal is a good idea. But if we need the client to find the session ID, |
548 | 517 | # we should think about a better way to do it. I believe we can achieve it with other means. |
@@ -612,7 +581,7 @@ def start_get_stream() -> None: |
612 | 581 | ) |
613 | 582 |
|
614 | 583 | try: |
615 | | - yield read_stream, _SessionAwareWriteStream(write_stream, transport) |
| 584 | + yield read_stream, write_stream |
616 | 585 | finally: |
617 | 586 | if transport.session_id and terminate_on_close: |
618 | 587 | await transport.terminate_session(client) |
|
0 commit comments