|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from codex.app_server._async_client import AsyncEventsClient, AsyncTurnStream |
8 | | -from codex.app_server.errors import AppServerProtocolError |
| 8 | +from codex.app_server.errors import AppServerProtocolError, AppServerTurnError |
9 | 9 | from codex.app_server.models import ReviewResult |
10 | 10 | from codex.protocol import types as protocol |
11 | 11 |
|
@@ -34,6 +34,17 @@ def update_predicate(self, predicate: object) -> None: |
34 | 34 | self.updated_predicate = predicate |
35 | 35 |
|
36 | 36 |
|
| 37 | +class _QueuedSubscription(_FakeSubscription): |
| 38 | + def __init__(self, notifications: list[protocol.Notification]) -> None: |
| 39 | + super().__init__() |
| 40 | + self._notifications = notifications |
| 41 | + |
| 42 | + async def next(self) -> protocol.Notification: |
| 43 | + if not self._notifications: |
| 44 | + raise StopAsyncIteration |
| 45 | + return self._notifications.pop(0) |
| 46 | + |
| 47 | + |
37 | 48 | class _FakeThread: |
38 | 49 | def __init__(self) -> None: |
39 | 50 | self.id = "thr-1" |
@@ -76,6 +87,94 @@ def test_async_events_client_subscribe_delegates_to_session() -> None: |
76 | 87 | assert session.calls == [(["turn/completed"], None)] |
77 | 88 |
|
78 | 89 |
|
| 90 | +def test_async_events_client_subscribe_command_exec_output_filters_by_process_id() -> None: |
| 91 | + session = _FakeSession() |
| 92 | + events = AsyncEventsClient(session) # type: ignore[arg-type] |
| 93 | + |
| 94 | + subscription = events.subscribe_command_exec_output("proc-1") |
| 95 | + |
| 96 | + assert subscription == "subscription" |
| 97 | + methods, predicate = session.calls[0] |
| 98 | + assert methods == {"command/exec/outputDelta"} |
| 99 | + assert callable(predicate) |
| 100 | + matching = protocol.CommandExecOutputDeltaNotificationModel.model_validate( |
| 101 | + { |
| 102 | + "method": "command/exec/outputDelta", |
| 103 | + "params": { |
| 104 | + "capReached": False, |
| 105 | + "deltaBase64": "aGVsbG8K", |
| 106 | + "processId": "proc-1", |
| 107 | + "stream": "stdout", |
| 108 | + }, |
| 109 | + } |
| 110 | + ) |
| 111 | + other_process = protocol.CommandExecOutputDeltaNotificationModel.model_validate( |
| 112 | + { |
| 113 | + "method": "command/exec/outputDelta", |
| 114 | + "params": { |
| 115 | + "capReached": False, |
| 116 | + "deltaBase64": "aGVsbG8K", |
| 117 | + "processId": "proc-2", |
| 118 | + "stream": "stdout", |
| 119 | + }, |
| 120 | + } |
| 121 | + ) |
| 122 | + |
| 123 | + assert predicate(matching) is True |
| 124 | + assert predicate(other_process) is False |
| 125 | + |
| 126 | + |
| 127 | +def test_async_events_client_subscribe_process_events_filters_by_process_handle() -> None: |
| 128 | + session = _FakeSession() |
| 129 | + events = AsyncEventsClient(session) # type: ignore[arg-type] |
| 130 | + |
| 131 | + subscription = events.subscribe_process_events("proc-handle-1") |
| 132 | + |
| 133 | + assert subscription == "subscription" |
| 134 | + methods, predicate = session.calls[0] |
| 135 | + assert methods == {"process/outputDelta", "process/exited"} |
| 136 | + assert callable(predicate) |
| 137 | + output = protocol.ProcessOutputDeltaNotificationModel.model_validate( |
| 138 | + { |
| 139 | + "method": "process/outputDelta", |
| 140 | + "params": { |
| 141 | + "capReached": False, |
| 142 | + "deltaBase64": "aGVsbG8K", |
| 143 | + "processHandle": "proc-handle-1", |
| 144 | + "stream": "stdout", |
| 145 | + }, |
| 146 | + } |
| 147 | + ) |
| 148 | + exited = protocol.ProcessExitedNotificationModel.model_validate( |
| 149 | + { |
| 150 | + "method": "process/exited", |
| 151 | + "params": { |
| 152 | + "exitCode": 0, |
| 153 | + "processHandle": "proc-handle-1", |
| 154 | + "stderr": "", |
| 155 | + "stderrCapReached": False, |
| 156 | + "stdout": "", |
| 157 | + "stdoutCapReached": False, |
| 158 | + }, |
| 159 | + } |
| 160 | + ) |
| 161 | + other_handle = protocol.ProcessOutputDeltaNotificationModel.model_validate( |
| 162 | + { |
| 163 | + "method": "process/outputDelta", |
| 164 | + "params": { |
| 165 | + "capReached": False, |
| 166 | + "deltaBase64": "aGVsbG8K", |
| 167 | + "processHandle": "proc-handle-2", |
| 168 | + "stream": "stdout", |
| 169 | + }, |
| 170 | + } |
| 171 | + ) |
| 172 | + |
| 173 | + assert predicate(output) is True |
| 174 | + assert predicate(exited) is True |
| 175 | + assert predicate(other_handle) is False |
| 176 | + |
| 177 | + |
79 | 178 | def test_async_turn_stream_scope_predicate_filters_by_thread_and_turn() -> None: |
80 | 179 | predicate = AsyncTurnStream._scope_predicate("thr-1", "turn-1") |
81 | 180 |
|
@@ -339,6 +438,102 @@ async def scenario() -> None: |
339 | 438 | asyncio.run(scenario()) |
340 | 439 |
|
341 | 440 |
|
| 441 | +def test_async_turn_stream_raises_and_closes_on_non_retryable_error_notification() -> None: |
| 442 | + error_notification = protocol.ErrorNotificationModel.model_validate( |
| 443 | + { |
| 444 | + "method": "error", |
| 445 | + "params": { |
| 446 | + "threadId": "thr-1", |
| 447 | + "turnId": "turn-1", |
| 448 | + "willRetry": False, |
| 449 | + "error": { |
| 450 | + "message": "model unavailable", |
| 451 | + "additionalDetails": "try another model", |
| 452 | + }, |
| 453 | + }, |
| 454 | + } |
| 455 | + ) |
| 456 | + |
| 457 | + async def scenario() -> None: |
| 458 | + subscription = _QueuedSubscription([error_notification]) |
| 459 | + stream = AsyncTurnStream( |
| 460 | + _FakeThread(), # type: ignore[arg-type] |
| 461 | + subscription, # type: ignore[arg-type] |
| 462 | + protocol.Turn.model_validate(_turn_payload(status="inProgress")), |
| 463 | + ) |
| 464 | + |
| 465 | + with pytest.raises(AppServerTurnError, match="model unavailable: try another model"): |
| 466 | + await stream.__anext__() |
| 467 | + |
| 468 | + assert subscription.closed is True |
| 469 | + |
| 470 | + asyncio.run(scenario()) |
| 471 | + |
| 472 | + |
| 473 | +def test_async_turn_stream_yields_retryable_error_notification() -> None: |
| 474 | + error_notification = protocol.ErrorNotificationModel.model_validate( |
| 475 | + { |
| 476 | + "method": "error", |
| 477 | + "params": { |
| 478 | + "threadId": "thr-1", |
| 479 | + "turnId": "turn-1", |
| 480 | + "willRetry": True, |
| 481 | + "error": {"message": "temporary outage"}, |
| 482 | + }, |
| 483 | + } |
| 484 | + ) |
| 485 | + |
| 486 | + async def scenario() -> None: |
| 487 | + subscription = _QueuedSubscription([error_notification]) |
| 488 | + stream = AsyncTurnStream( |
| 489 | + _FakeThread(), # type: ignore[arg-type] |
| 490 | + subscription, # type: ignore[arg-type] |
| 491 | + protocol.Turn.model_validate(_turn_payload(status="inProgress")), |
| 492 | + ) |
| 493 | + |
| 494 | + assert await stream.__anext__() == error_notification |
| 495 | + assert subscription.closed is False |
| 496 | + |
| 497 | + asyncio.run(scenario()) |
| 498 | + |
| 499 | + |
| 500 | +def test_async_turn_stream_wait_preserves_retryable_error_notifications() -> None: |
| 501 | + error_notification = protocol.ErrorNotificationModel.model_validate( |
| 502 | + { |
| 503 | + "method": "error", |
| 504 | + "params": { |
| 505 | + "threadId": "thr-1", |
| 506 | + "turnId": "turn-1", |
| 507 | + "willRetry": True, |
| 508 | + "error": { |
| 509 | + "message": "temporary outage", |
| 510 | + "additionalDetails": "retrying with fallback", |
| 511 | + }, |
| 512 | + }, |
| 513 | + } |
| 514 | + ) |
| 515 | + turn_completed = protocol.TurnCompletedNotificationModel.model_validate( |
| 516 | + { |
| 517 | + "method": "turn/completed", |
| 518 | + "params": {"threadId": "thr-1", "turn": _turn_payload(status="completed")}, |
| 519 | + } |
| 520 | + ) |
| 521 | + |
| 522 | + async def scenario() -> None: |
| 523 | + subscription = _QueuedSubscription([error_notification, turn_completed]) |
| 524 | + stream = AsyncTurnStream( |
| 525 | + _FakeThread(), # type: ignore[arg-type] |
| 526 | + subscription, # type: ignore[arg-type] |
| 527 | + protocol.Turn.model_validate(_turn_payload(status="inProgress")), |
| 528 | + ) |
| 529 | + |
| 530 | + assert await stream.wait() is stream |
| 531 | + assert stream.retryable_error_notifications == (error_notification,) |
| 532 | + assert stream.retryable_errors == (error_notification.params.error,) |
| 533 | + |
| 534 | + asyncio.run(scenario()) |
| 535 | + |
| 536 | + |
342 | 537 | def test_async_turn_stream_raise_for_terminal_status_requires_completion() -> None: |
343 | 538 | stream = AsyncTurnStream( |
344 | 539 | _FakeThread(), # type: ignore[arg-type] |
|
0 commit comments