Skip to content

Commit 78e0030

Browse files
committed
fix(app-server): surface turn errors and process events
1 parent e14682c commit 78e0030

5 files changed

Lines changed: 223 additions & 9 deletions

File tree

codex/app_server/_async_client.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from codex.app_server._async_threads import AsyncAppServerThread as AsyncAppServerThread
2323
from codex.app_server._async_threads import AsyncTurnStream as AsyncTurnStream
2424
from codex.app_server._async_threads import _ThreadClient
25-
from codex.app_server._protocol_helpers import RequestHandler
25+
from codex.app_server._protocol_helpers import Notification, RequestHandler
2626
from codex.app_server._session import _AsyncNotificationSubscription, _AsyncSession
2727
from codex.app_server.models import (
2828
InitializeResult,
@@ -107,6 +107,35 @@ def __init__(self, session: _AsyncSession) -> None:
107107
def subscribe(self, methods: Collection[str] | None = None) -> _AsyncNotificationSubscription:
108108
return self._session.subscribe_notifications(methods)
109109

110+
def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription:
111+
"""Subscribe to `command/exec/outputDelta` notifications for one process id."""
112+
113+
def predicate(notification: Notification) -> bool:
114+
return (
115+
isinstance(notification, protocol.CommandExecOutputDeltaNotificationModel)
116+
and notification.params.processId == process_id
117+
)
118+
119+
return self._session.subscribe_notifications(
120+
{"command/exec/outputDelta"},
121+
predicate=predicate,
122+
)
123+
124+
def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription:
125+
"""Subscribe to `process/outputDelta` and `process/exited` for one process handle."""
126+
127+
def predicate(notification: Notification) -> bool:
128+
if isinstance(notification, protocol.ProcessOutputDeltaNotificationModel):
129+
return notification.params.processHandle == process_handle
130+
if isinstance(notification, protocol.ProcessExitedNotificationModel):
131+
return notification.params.processHandle == process_handle
132+
return False
133+
134+
return self._session.subscribe_notifications(
135+
{"process/outputDelta", "process/exited"},
136+
predicate=predicate,
137+
)
138+
110139

111140
class AsyncAppServerClient:
112141
"""Async client for `codex app-server`."""

codex/app_server/_async_threads.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
DEFAULT_REVIEW_DELIVERY = protocol.ReviewDelivery("inline")
3838

3939
_TURN_STREAM_NOTIFICATION_METHODS = {
40+
"error",
4041
"turn/started",
4142
"turn/completed",
4243
"turn/diff/updated",
@@ -54,6 +55,7 @@
5455
"item/reasoning/summaryPartAdded",
5556
"item/reasoning/textDelta",
5657
"item/commandExecution/outputDelta",
58+
"item/commandExecution/terminalInteraction",
5759
"item/fileChange/outputDelta",
5860
"serverRequest/resolved",
5961
}
@@ -188,6 +190,15 @@ async def __anext__(self) -> Notification:
188190
raise StopAsyncIteration
189191
notification = await self._subscription.next()
190192
self._apply(notification)
193+
if isinstance(notification, protocol.ErrorNotificationModel):
194+
if not notification.params.willRetry:
195+
self._done = True
196+
await self.close()
197+
error = notification.params.error
198+
message = error.message
199+
if error.additionalDetails is not None and error.additionalDetails != "":
200+
message = f"{message}: {error.additionalDetails}"
201+
raise AppServerTurnError(message)
191202
if isinstance(notification, protocol.TurnCompletedNotificationModel):
192203
self._done = True
193204
return notification

codex/app_server/_sync_threads.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def subscribe(
2525
methods: Collection[str] | None = None,
2626
) -> _AsyncNotificationSubscription: ...
2727

28+
def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription: ...
29+
30+
def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription: ...
31+
2832

2933
class _AsyncTurnStreamLike(Protocol):
3034
initial_turn: protocol.Turn
@@ -157,6 +161,20 @@ def subscribe(self, methods: Collection[str] | None = None) -> NotificationSubsc
157161
self._run,
158162
)
159163

164+
def subscribe_command_exec_output(self, process_id: str) -> NotificationSubscription:
165+
"""Subscribe to `command/exec/outputDelta` notifications for one process id."""
166+
return NotificationSubscription(
167+
self._async_events.subscribe_command_exec_output(process_id),
168+
self._run,
169+
)
170+
171+
def subscribe_process_events(self, process_handle: str) -> NotificationSubscription:
172+
"""Subscribe to `process/outputDelta` and `process/exited` for one process handle."""
173+
return NotificationSubscription(
174+
self._async_events.subscribe_process_events(process_handle),
175+
self._run,
176+
)
177+
160178

161179
class TurnStream(_SyncRunner):
162180
"""Synchronous iterator over protocol-native notifications for a single turn."""

tests/test_app_server_async_client.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from codex.app_server._async_client import AsyncEventsClient, AsyncTurnStream
8-
from codex.app_server.errors import AppServerProtocolError
8+
from codex.app_server.errors import AppServerProtocolError, AppServerTurnError
99
from codex.app_server.models import ReviewResult
1010
from codex.protocol import types as protocol
1111

@@ -34,6 +34,17 @@ def update_predicate(self, predicate: object) -> None:
3434
self.updated_predicate = predicate
3535

3636

37+
class _QueuedSubscription(_FakeSubscription):
38+
def __init__(self, notifications: list[protocol.Notification]) -> None:
39+
super().__init__()
40+
self._notifications = notifications
41+
42+
async def next(self) -> protocol.Notification:
43+
if not self._notifications:
44+
raise StopAsyncIteration
45+
return self._notifications.pop(0)
46+
47+
3748
class _FakeThread:
3849
def __init__(self) -> None:
3950
self.id = "thr-1"
@@ -76,6 +87,94 @@ def test_async_events_client_subscribe_delegates_to_session() -> None:
7687
assert session.calls == [(["turn/completed"], None)]
7788

7889

90+
def test_async_events_client_subscribe_command_exec_output_filters_by_process_id() -> None:
91+
session = _FakeSession()
92+
events = AsyncEventsClient(session) # type: ignore[arg-type]
93+
94+
subscription = events.subscribe_command_exec_output("proc-1")
95+
96+
assert subscription == "subscription"
97+
methods, predicate = session.calls[0]
98+
assert methods == {"command/exec/outputDelta"}
99+
assert callable(predicate)
100+
matching = protocol.CommandExecOutputDeltaNotificationModel.model_validate(
101+
{
102+
"method": "command/exec/outputDelta",
103+
"params": {
104+
"capReached": False,
105+
"deltaBase64": "aGVsbG8K",
106+
"processId": "proc-1",
107+
"stream": "stdout",
108+
},
109+
}
110+
)
111+
other_process = protocol.CommandExecOutputDeltaNotificationModel.model_validate(
112+
{
113+
"method": "command/exec/outputDelta",
114+
"params": {
115+
"capReached": False,
116+
"deltaBase64": "aGVsbG8K",
117+
"processId": "proc-2",
118+
"stream": "stdout",
119+
},
120+
}
121+
)
122+
123+
assert predicate(matching) is True
124+
assert predicate(other_process) is False
125+
126+
127+
def test_async_events_client_subscribe_process_events_filters_by_process_handle() -> None:
128+
session = _FakeSession()
129+
events = AsyncEventsClient(session) # type: ignore[arg-type]
130+
131+
subscription = events.subscribe_process_events("proc-handle-1")
132+
133+
assert subscription == "subscription"
134+
methods, predicate = session.calls[0]
135+
assert methods == {"process/outputDelta", "process/exited"}
136+
assert callable(predicate)
137+
output = protocol.ProcessOutputDeltaNotificationModel.model_validate(
138+
{
139+
"method": "process/outputDelta",
140+
"params": {
141+
"capReached": False,
142+
"deltaBase64": "aGVsbG8K",
143+
"processHandle": "proc-handle-1",
144+
"stream": "stdout",
145+
},
146+
}
147+
)
148+
exited = protocol.ProcessExitedNotificationModel.model_validate(
149+
{
150+
"method": "process/exited",
151+
"params": {
152+
"exitCode": 0,
153+
"processHandle": "proc-handle-1",
154+
"stderr": "",
155+
"stderrCapReached": False,
156+
"stdout": "",
157+
"stdoutCapReached": False,
158+
},
159+
}
160+
)
161+
other_handle = protocol.ProcessOutputDeltaNotificationModel.model_validate(
162+
{
163+
"method": "process/outputDelta",
164+
"params": {
165+
"capReached": False,
166+
"deltaBase64": "aGVsbG8K",
167+
"processHandle": "proc-handle-2",
168+
"stream": "stdout",
169+
},
170+
}
171+
)
172+
173+
assert predicate(output) is True
174+
assert predicate(exited) is True
175+
assert predicate(other_handle) is False
176+
177+
79178
def test_async_turn_stream_scope_predicate_filters_by_thread_and_turn() -> None:
80179
predicate = AsyncTurnStream._scope_predicate("thr-1", "turn-1")
81180

@@ -339,6 +438,65 @@ async def scenario() -> None:
339438
asyncio.run(scenario())
340439

341440

441+
def test_async_turn_stream_raises_and_closes_on_non_retryable_error_notification() -> None:
442+
error_notification = protocol.ErrorNotificationModel.model_validate(
443+
{
444+
"method": "error",
445+
"params": {
446+
"threadId": "thr-1",
447+
"turnId": "turn-1",
448+
"willRetry": False,
449+
"error": {
450+
"message": "model unavailable",
451+
"additionalDetails": "try another model",
452+
},
453+
},
454+
}
455+
)
456+
457+
async def scenario() -> None:
458+
subscription = _QueuedSubscription([error_notification])
459+
stream = AsyncTurnStream(
460+
_FakeThread(), # type: ignore[arg-type]
461+
subscription, # type: ignore[arg-type]
462+
protocol.Turn.model_validate(_turn_payload(status="inProgress")),
463+
)
464+
465+
with pytest.raises(AppServerTurnError, match="model unavailable: try another model"):
466+
await stream.__anext__()
467+
468+
assert subscription.closed is True
469+
470+
asyncio.run(scenario())
471+
472+
473+
def test_async_turn_stream_yields_retryable_error_notification() -> None:
474+
error_notification = protocol.ErrorNotificationModel.model_validate(
475+
{
476+
"method": "error",
477+
"params": {
478+
"threadId": "thr-1",
479+
"turnId": "turn-1",
480+
"willRetry": True,
481+
"error": {"message": "temporary outage"},
482+
},
483+
}
484+
)
485+
486+
async def scenario() -> None:
487+
subscription = _QueuedSubscription([error_notification])
488+
stream = AsyncTurnStream(
489+
_FakeThread(), # type: ignore[arg-type]
490+
subscription, # type: ignore[arg-type]
491+
protocol.Turn.model_validate(_turn_payload(status="inProgress")),
492+
)
493+
494+
assert await stream.__anext__() == error_notification
495+
assert subscription.closed is False
496+
497+
asyncio.run(scenario())
498+
499+
342500
def test_async_turn_stream_raise_for_terminal_status_requires_completion() -> None:
343501
stream = AsyncTurnStream(
344502
_FakeThread(), # type: ignore[arg-type]

tests/test_stream_interaction.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -323,21 +323,19 @@ def test_consume_two_turns_same_thread() -> None:
323323
client.close()
324324

325325

326-
def test_turn_completed_not_in_notification_methods_does_not_hang() -> None:
326+
def test_turn_stream_yields_terminal_interaction_and_does_not_hang() -> None:
327327
"""Reproduce review-action handling when terminalInteraction is emitted.
328328
329-
codex-review-action handles `item/commandExecution/terminalInteraction`, but
330-
`_TURN_STREAM_NOTIFICATION_METHODS` does not subscribe to that method.
331-
This test verifies such events are effectively dropped for the turn stream and
332-
do not cause `wait()` to hang.
329+
codex-review-action handles `item/commandExecution/terminalInteraction`;
330+
the turn stream should now deliver it as a typed protocol notification and
331+
still allow a later `wait()` call to finish.
333332
"""
334333

335334
client, transport = _make_sync_client()
336335
try:
337336
thread = client.start_thread()
338337
stream = thread.run("Terminal interaction compatibility")
339338

340-
# This method is intentionally *not* part of turn stream subscribed methods.
341339
transport.push(
342340
{
343341
"method": "item/commandExecution/terminalInteraction",
@@ -374,8 +372,8 @@ def test_turn_completed_not_in_notification_methods_does_not_hang() -> None:
374372
task_complete, events = _consume_like_review_action(stream)
375373

376374
assert task_complete is True
377-
# If terminalInteraction had been routed to stream, we'd see an extra event.
378375
assert [type(event) for event in events] == [
376+
protocol.ItemCommandExecutionTerminalInteractionNotification,
379377
protocol.ItemCompletedNotificationModel,
380378
protocol.TurnCompletedNotificationModel,
381379
]

0 commit comments

Comments
 (0)