Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/claude_agent_sdk/_internal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,31 @@ async def process_query(
"parent_tool_use_id": None,
}
await chosen_transport.write(json.dumps(user_message) + "\n")
await chosen_transport.end_input()

has_bidirectional_needs = bool(sdk_mcp_servers) or bool(
configured_options.hooks
)
if not has_bidirectional_needs:
# No bidirectional control protocol needed: close stdin immediately
await chosen_transport.end_input()
elif query._tg:
# Defer stdin close until the conversation ends (result message).
# The CLI needs stdin open for the entire conversation to send
# tools/list, tools/call, and hook callbacks via control protocol.
# No timeout needed for string prompts — the result message
# always arrives when the CLI finishes, and task group
# cancellation triggers the finally block on abnormal exit.
async def _deferred_end_input() -> None:
try:
await query._first_result_event.wait()
finally:
await chosen_transport.end_input()

query._tg.start_soon(_deferred_end_input)
else:
# _tg should always exist after start(), but close stdin
# defensively to prevent resource leaks
await chosen_transport.end_input()
elif isinstance(prompt, AsyncIterable) and query._tg:
# Stream input in background for async iterables
query._tg.start_soon(query.stream_input, prompt)
Expand Down
2 changes: 2 additions & 0 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ async def initialize(self) -> dict[str, Any] | None:
}
if self._agents:
request["agents"] = self._agents
if self.sdk_mcp_servers:
request["sdkMcpServers"] = list(self.sdk_mcp_servers.keys())

# Use longer timeout for initialize since MCP servers may take time to start
response = await self._send_control_request(
Expand Down
1 change: 1 addition & 0 deletions src/claude_agent_sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ class SDKControlInitializeRequest(TypedDict):
subtype: Literal["initialize"]
hooks: dict[HookEvent, Any] | None
agents: NotRequired[dict[str, dict[str, Any]]]
sdkMcpServers: NotRequired[list[str]]


class SDKControlSetPermissionModeRequest(TypedDict):
Expand Down
220 changes: 220 additions & 0 deletions tests/test_deferred_end_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Tests for deferred end_input() with SDK MCP servers on string prompts.
When SDK MCP servers or hooks are present and prompt is a string,
end_input() must be deferred until the first result message is received.
Closing stdin immediately prevents the CLI from completing tools/list
via control protocol, making SDK MCP tools invisible to the model.
See: client.py process_query() string prompt handling.
"""

from unittest.mock import AsyncMock, Mock, patch

import anyio

from claude_agent_sdk import ClaudeAgentOptions, query

RESULT_MESSAGE = {
"type": "result",
"subtype": "success",
"duration_ms": 100,
"duration_api_ms": 80,
"is_error": False,
"num_turns": 1,
"session_id": "test",
"total_cost_usd": 0.001,
}


def _make_mock_transport() -> Mock:
"""Create a mock transport with standard async methods."""
mock_transport = Mock()
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)

async def mock_receive() -> None: # type: ignore[return]
yield RESULT_MESSAGE

mock_transport.read_messages = mock_receive
return mock_transport


class TestDeferredEndInput:
"""Test that end_input() is deferred for string prompts with SDK MCP servers."""

def test_string_prompt_without_sdk_mcp_servers_closes_stdin_immediately(
self,
) -> None:
"""Without SDK MCP servers, end_input() should be called right after write."""

async def _test() -> None:
mock_transport = _make_mock_transport()

# Track call order
call_order: list[str] = []
original_write = mock_transport.write

async def tracking_write(data: str) -> None:
call_order.append("write")
return await original_write(data)

async def tracking_end_input() -> None:
call_order.append("end_input")

mock_transport.write = AsyncMock(side_effect=tracking_write)
mock_transport.end_input = AsyncMock(side_effect=tracking_end_input)

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport",
return_value=mock_transport,
),
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
options = ClaudeAgentOptions()
async for _ in query(prompt="test prompt", options=options):
pass

# end_input should be called immediately after write
assert "write" in call_order
assert "end_input" in call_order
write_idx = call_order.index("write")
end_input_idx = call_order.index("end_input")
assert end_input_idx == write_idx + 1, (
f"end_input should follow write immediately, got order: {call_order}"
)

anyio.run(_test)

def test_string_prompt_with_sdk_mcp_servers_calls_end_input(self) -> None:
"""With SDK MCP servers, end_input() must be called (deferred via task group)."""

async def _test() -> None:
mock_transport = _make_mock_transport()
mock_mcp_server = Mock()

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport",
return_value=mock_transport,
),
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
options = ClaudeAgentOptions(
mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item]
)
async for _ in query(prompt="test prompt", options=options):
pass

mock_transport.end_input.assert_called_once()

anyio.run(_test)

def test_deferred_end_input_waits_for_result_event(self) -> None:
"""end_input() must not be called before the result event fires.
Uses a delayed result message to verify that end_input waits
for _first_result_event rather than closing stdin immediately.
"""

async def _test() -> None:
mock_transport = _make_mock_transport()
end_input_called_before_result = False

# Override read_messages to delay the result
result_gate = anyio.Event()

async def delayed_receive() -> None: # type: ignore[return]
await result_gate.wait()
yield RESULT_MESSAGE

mock_transport.read_messages = delayed_receive

original_end_input = mock_transport.end_input

async def tracking_end_input() -> None:
nonlocal end_input_called_before_result
if not result_gate.is_set():
end_input_called_before_result = True
return await original_end_input()

mock_transport.end_input = AsyncMock(side_effect=tracking_end_input)

mock_mcp_server = Mock()

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport",
return_value=mock_transport,
),
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
options = ClaudeAgentOptions(
mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item]
)

async def consume_and_release() -> None:
# Give task group time to start _deferred_end_input
await anyio.sleep(0.05)
# Now release the result — end_input should NOT have been called yet
result_gate.set()

async with anyio.create_task_group() as tg:
tg.start_soon(consume_and_release)
async for _ in query(prompt="test prompt", options=options):
pass

assert not end_input_called_before_result, (
"end_input must not be called before the result event fires"
)
mock_transport.end_input.assert_called_once()

anyio.run(_test)

def test_end_input_called_even_with_sdk_mcp_servers(self) -> None:
"""end_input() must always eventually be called to avoid resource leaks."""

async def _test() -> None:
end_input_called = anyio.Event()

async def tracking_end_input() -> None:
end_input_called.set()

mock_transport = _make_mock_transport()
mock_transport.end_input = AsyncMock(side_effect=tracking_end_input)

mock_mcp_server = Mock()

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport",
return_value=mock_transport,
),
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
options = ClaudeAgentOptions(
mcp_servers={"team": {"type": "sdk", "instance": mock_mcp_server}}, # type: ignore[typeddict-item]
)
async for _ in query(prompt="test prompt", options=options):
pass

assert end_input_called.is_set(), (
"end_input must be called even with SDK MCP servers (deferred, not skipped)"
)

anyio.run(_test)
Loading