Skip to content

Commit 6237bc3

Browse files
authored
Merge branch 'main' into fix/stateless-task-group-leak
2 parents bb1a218 + d5b9155 commit 6237bc3

File tree

33 files changed

+874
-229
lines changed

33 files changed

+874
-229
lines changed

.github/workflows/claude-code-review.yml

Lines changed: 0 additions & 33 deletions
This file was deleted.

.github/workflows/claude.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ on:
1414
jobs:
1515
claude:
1616
if: |
17-
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
17+
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude') && !startsWith(github.event.comment.body, '@claude review')) ||
1818
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
1919
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
2020
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from urllib.parse import parse_qs, urlparse
1919

2020
import httpx
21-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
21+
from mcp.client._transport import ReadStream, WriteStream
2222
from mcp.client.auth import OAuthClientProvider, TokenStorage
2323
from mcp.client.session import ClientSession
2424
from mcp.client.sse import sse_client
@@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None:
241241

242242
async def _run_session(
243243
self,
244-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
245-
write_stream: MemoryObjectSendStream[SessionMessage],
244+
read_stream: ReadStream[SessionMessage | Exception],
245+
write_stream: WriteStream[SessionMessage],
246246
):
247247
"""Run the MCP session with the given streams."""
248248
print("🤝 Initializing MCP session...")

pyproject.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
]
2727
dependencies = [
2828
"anyio>=4.9",
29-
"httpx>=0.27.1",
29+
"httpx>=0.27.1,<1.0.0",
3030
"httpx-sse>=0.4",
3131
"pydantic>=2.12.0",
3232
"starlette>=0.48.0; python_version >= '3.14'",
@@ -40,6 +40,7 @@ dependencies = [
4040
"pyjwt[crypto]>=2.10.1",
4141
"typing-extensions>=4.13.0",
4242
"typing-inspection>=0.4.1",
43+
"opentelemetry-api>=1.28.0",
4344
]
4445

4546
[project.optional-dependencies]
@@ -71,6 +72,7 @@ dev = [
7172
"coverage[toml]>=7.10.7,<=7.13",
7273
"pillow>=12.0",
7374
"strict-no-cover",
75+
"logfire>=3.0.0",
7476
]
7577
docs = [
7678
"mkdocs>=1.6.1",
@@ -219,13 +221,10 @@ skip_covered = true
219221
show_missing = true
220222
ignore_errors = true
221223
precision = 2
222-
exclude_lines = [
223-
"pragma: no cover",
224+
exclude_also = [
224225
"pragma: lax no cover",
225-
"if TYPE_CHECKING:",
226226
"@overload",
227227
"raise NotImplementedError",
228-
"^\\s*\\.\\.\\.\\s*$",
229228
]
230229

231230
# https://coverage.readthedocs.io/en/latest/config.html#paths

src/mcp/client/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from urllib.parse import urlparse
77

88
import anyio
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
109

1110
from mcp import types
11+
from mcp.client._transport import ReadStream, WriteStream
1212
from mcp.client.session import ClientSession
1313
from mcp.client.sse import sse_client
1414
from mcp.client.stdio import StdioServerParameters, stdio_client
@@ -33,8 +33,8 @@ async def message_handler(
3333

3434

3535
async def run_session(
36-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
37-
write_stream: MemoryObjectSendStream[SessionMessage],
36+
read_stream: ReadStream[SessionMessage | Exception],
37+
write_stream: WriteStream[SessionMessage],
3838
client_info: types.Implementation | None = None,
3939
):
4040
async with ClientSession(

src/mcp/client/_transport.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from contextlib import AbstractAsyncContextManager
66
from typing import Protocol
77

8-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
9-
8+
from mcp.shared._stream_protocols import ReadStream, WriteStream
109
from mcp.shared.message import SessionMessage
1110

12-
TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
11+
__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"]
12+
13+
TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
1314

1415

1516
class Transport(AbstractAsyncContextManager[TransportStreams], Protocol):

src/mcp/client/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from typing import Any, Protocol
55

66
import anyio.lowlevel
7-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
87
from pydantic import TypeAdapter
98

109
from mcp import types
10+
from mcp.client._transport import ReadStream, WriteStream
1111
from mcp.client.experimental import ExperimentalClientFeatures
1212
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1313
from mcp.shared._context import RequestContext
@@ -109,8 +109,8 @@ class ClientSession(
109109
):
110110
def __init__(
111111
self,
112-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
113-
write_stream: MemoryObjectSendStream[SessionMessage],
112+
read_stream: ReadStream[SessionMessage | Exception],
113+
write_stream: WriteStream[SessionMessage],
114114
read_timeout_seconds: float | None = None,
115115
sampling_callback: SamplingFnT | None = None,
116116
elicitation_callback: ElicitationFnT | None = None,

src/mcp/client/sse.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import anyio
88
import httpx
99
from anyio.abc import TaskStatus
10-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1110
from httpx_sse import aconnect_sse
1211
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
14+
from mcp.shared._context_streams import create_context_streams
1515
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1616
from mcp.shared.message import SessionMessage
1717

@@ -51,12 +51,6 @@ async def sse_client(
5151
auth: Optional HTTPX authentication handler.
5252
on_session_created: Optional callback invoked with the session ID when received.
5353
"""
54-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
55-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
56-
57-
write_stream: MemoryObjectSendStream[SessionMessage]
58-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
59-
6054
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
6155
async with httpx_client_factory(
6256
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
@@ -65,8 +59,8 @@ async def sse_client(
6559
event_source.response.raise_for_status()
6660
logger.debug("SSE connection established")
6761

68-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
69-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
62+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
63+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
7064

7165
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7266
try:
@@ -124,7 +118,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
124118
async def post_writer(endpoint_url: str):
125119
try:
126120
async with write_stream_reader, write_stream:
127-
async for session_message in write_stream_reader:
121+
122+
async def _send_message(session_message: SessionMessage) -> None:
128123
logger.debug(f"Sending client message: {session_message}")
129124
response = await client.post(
130125
endpoint_url,
@@ -136,6 +131,14 @@ async def post_writer(endpoint_url: str):
136131
)
137132
response.raise_for_status()
138133
logger.debug(f"Client message sent successfully: {response.status_code}")
134+
135+
async for session_message in write_stream_reader:
136+
sender_ctx = write_stream_reader.last_context
137+
if sender_ctx is not None:
138+
async with anyio.create_task_group() as tg:
139+
sender_ctx.run(tg.start_soon, _send_message, session_message)
140+
else:
141+
await _send_message(session_message) # pragma: no cover
139142
except Exception: # pragma: lax no cover
140143
logger.exception("Error in post_writer")
141144

src/mcp/client/streamable_http.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import anyio
1212
import httpx
1313
from anyio.abc import TaskGroup
14-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1514
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
1615
from pydantic import ValidationError
1716

1817
from mcp.client._transport import TransportStreams
18+
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
1919
from mcp.shared._httpx_utils import create_mcp_http_client
2020
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2121
from mcp.types import (
@@ -38,8 +38,8 @@
3838

3939
# TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here.
4040
SessionMessageOrError = SessionMessage | Exception
41-
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
42-
StreamReader = MemoryObjectReceiveStream[SessionMessage]
41+
StreamWriter = ContextSendStream[SessionMessageOrError]
42+
StreamReader = ContextReceiveStream[SessionMessage]
4343

4444
MCP_SESSION_ID = "mcp-session-id"
4545
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
@@ -434,14 +434,15 @@ async def post_writer(
434434
client: httpx.AsyncClient,
435435
write_stream_reader: StreamReader,
436436
read_stream_writer: StreamWriter,
437-
write_stream: MemoryObjectSendStream[SessionMessage],
437+
write_stream: ContextSendStream[SessionMessage],
438438
start_get_stream: Callable[[], None],
439439
tg: TaskGroup,
440440
) -> None:
441441
"""Handle writing requests to the server."""
442442
try:
443443
async with write_stream_reader, read_stream_writer, write_stream:
444-
async for session_message in write_stream_reader:
444+
445+
async def _handle_message(session_message: SessionMessage) -> None:
445446
message = session_message.message
446447
metadata = (
447448
session_message.metadata
@@ -478,6 +479,14 @@ async def handle_request_async():
478479
else:
479480
await handle_request_async()
480481

482+
async for session_message in write_stream_reader:
483+
sender_ctx = write_stream_reader.last_context
484+
if sender_ctx is not None:
485+
async with anyio.create_task_group() as tg_local:
486+
sender_ctx.run(tg_local.start_soon, _handle_message, session_message)
487+
else:
488+
await _handle_message(session_message) # pragma: no cover
489+
481490
except Exception: # pragma: lax no cover
482491
logger.exception("Error in post_writer")
483492

@@ -547,8 +556,8 @@ async def streamable_http_client(
547556
if not client_provided:
548557
await stack.enter_async_context(client)
549558

550-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
551-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
559+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
560+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
552561

553562
async with (
554563
read_stream_writer,

0 commit comments

Comments
 (0)