Skip to content

Commit ef7e4b4

Browse files
committed
Remove unnecessary pragmas and fix generic type propagation
- Remove `# pragma: no branch` and `@runtime_checkable` from stream protocols - Replace `_CreateContextStreams` with a proper generic class inheriting from `tuple[ContextSendStream[T], ContextReceiveStream[T]]`, matching anyio's `create_memory_object_stream` pattern so bracket syntax propagates types - Update `MessageStream` type alias to match actual stream creation
1 parent b321dd2 commit ef7e4b4

File tree

3 files changed

+25
-36
lines changed

3 files changed

+25
-36
lines changed

src/mcp/shared/_context_streams.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,26 +103,17 @@ async def __aexit__(
103103
return None
104104

105105

106-
def _create_context_streams(
107-
max_buffer_size: float = 0,
108-
) -> tuple[ContextSendStream[Any], ContextReceiveStream[Any]]:
109-
raw_send: MemoryObjectSendStream[Any]
110-
raw_receive: MemoryObjectReceiveStream[Any]
111-
raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size)
112-
return ContextSendStream(raw_send), ContextReceiveStream(raw_receive)
106+
class create_context_streams(
107+
tuple[ContextSendStream[T], ContextReceiveStream[T]],
108+
):
109+
"""Create context-aware memory object streams.
113110
114-
115-
class _CreateContextStreams:
116-
"""Callable that supports ``create_context_streams[T](n)`` bracket syntax.
117-
118-
Matches anyio's ``create_memory_object_stream`` API style.
111+
Supports ``create_context_streams[T](n)`` bracket syntax,
112+
matching anyio's ``create_memory_object_stream`` API style.
119113
"""
120114

121-
def __getitem__(self, _item: Any) -> _CreateContextStreams:
122-
return self
123-
124-
def __call__(self, max_buffer_size: float = 0) -> tuple[ContextSendStream[Any], ContextReceiveStream[Any]]:
125-
return _create_context_streams(max_buffer_size)
126-
127-
128-
create_context_streams = _CreateContextStreams()
115+
def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]: # type: ignore[type-var]
116+
raw_send: MemoryObjectSendStream[Any]
117+
raw_receive: MemoryObjectReceiveStream[Any]
118+
raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size)
119+
return (ContextSendStream(raw_send), ContextReceiveStream(raw_receive))

src/mcp/shared/_stream_protocols.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,41 @@
77
from __future__ import annotations
88

99
from types import TracebackType
10-
from typing import Protocol, TypeVar, runtime_checkable
10+
from typing import Protocol, TypeVar
1111

1212
from typing_extensions import Self
1313

1414
T_co = TypeVar("T_co", covariant=True)
1515
T_contra = TypeVar("T_contra", contravariant=True)
1616

1717

18-
@runtime_checkable
19-
class ReadStream(Protocol[T_co]): # pragma: no branch
18+
class ReadStream(Protocol[T_co]):
2019
"""Protocol for reading items from a stream.
2120
2221
Consumers that need the sender's context should use
2322
``getattr(stream, 'last_context', None)``.
2423
"""
2524

26-
async def receive(self) -> T_co: ... # pragma: no branch
27-
async def aclose(self) -> None: ... # pragma: no branch
28-
def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch
29-
async def __anext__(self) -> T_co: ... # pragma: no branch
30-
async def __aenter__(self) -> Self: ... # pragma: no branch
31-
async def __aexit__( # pragma: no branch
25+
async def receive(self) -> T_co: ...
26+
async def aclose(self) -> None: ...
27+
def __aiter__(self) -> ReadStream[T_co]: ...
28+
async def __anext__(self) -> T_co: ...
29+
async def __aenter__(self) -> Self: ...
30+
async def __aexit__(
3231
self,
3332
exc_type: type[BaseException] | None,
3433
exc_val: BaseException | None,
3534
exc_tb: TracebackType | None,
3635
) -> bool | None: ...
3736

3837

39-
@runtime_checkable
40-
class WriteStream(Protocol[T_contra]): # pragma: no branch
38+
class WriteStream(Protocol[T_contra]):
4139
"""Protocol for writing items to a stream."""
4240

43-
async def send(self, item: T_contra, /) -> None: ... # pragma: no branch
44-
async def aclose(self) -> None: ... # pragma: no branch
45-
async def __aenter__(self) -> Self: ... # pragma: no branch
46-
async def __aexit__( # pragma: no branch
41+
async def send(self, item: T_contra, /) -> None: ...
42+
async def aclose(self) -> None: ...
43+
async def __aenter__(self) -> Self: ...
44+
async def __aexit__(
4745
self,
4846
exc_type: type[BaseException] | None,
4947
exc_val: BaseException | None,

src/mcp/shared/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
99
from mcp.shared.message import SessionMessage
1010

11-
MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage]]
11+
MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]]
1212

1313

1414
@asynccontextmanager

0 commit comments

Comments
 (0)