Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/agents/run_internal/tool_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,10 @@ async def _collect_runs_by_approval(
rejection_items.append(rejection_item)
continue

if approval_status is True:
approved_runs.append(run)
continue

needs_approval = True
if needs_approval_checker:
try:
Expand All @@ -424,25 +428,22 @@ async def _collect_runs_by_approval(
approved_runs.append(run)
continue

if approval_status is True:
approved_runs.append(run)
else:
function_tool = get_mapping_or_attr(run, "function_tool")
pending_item = existing_pending or ToolApprovalItem(
agent=agent,
raw_item=get_mapping_or_attr(run, "tool_call"),
tool_name=tool_name,
tool_namespace=get_tool_call_namespace(get_mapping_or_attr(run, "tool_call")),
tool_origin=(
get_function_tool_origin(function_tool)
if isinstance(function_tool, FunctionTool)
else None
),
tool_lookup_key=get_function_tool_lookup_key_for_call(
get_mapping_or_attr(run, "tool_call")
),
)
pending_interruption_adder(pending_item)
function_tool = get_mapping_or_attr(run, "function_tool")
pending_item = existing_pending or ToolApprovalItem(
agent=agent,
raw_item=get_mapping_or_attr(run, "tool_call"),
tool_name=tool_name,
tool_namespace=get_tool_call_namespace(get_mapping_or_attr(run, "tool_call")),
tool_origin=(
get_function_tool_origin(function_tool)
if isinstance(function_tool, FunctionTool)
else None
),
tool_lookup_key=get_function_tool_lookup_key_for_call(
get_mapping_or_attr(run, "tool_call")
),
)
pending_interruption_adder(pending_item)

return approved_runs, rejection_items

Expand Down
68 changes: 67 additions & 1 deletion tests/test_hitl_error_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@
ToolRunShellCall,
extract_tool_call_id,
)
from agents.run_internal.tool_planning import _select_function_tool_runs_for_resume
from agents.run_internal.tool_planning import (
_collect_runs_by_approval,
_select_function_tool_runs_for_resume,
)
from agents.run_state import RunState as RunStateClass
from agents.tool import HostedMCPTool
from agents.usage import Usage
Expand Down Expand Up @@ -1247,6 +1250,69 @@ async def _record_rejection(
assert rejections == []


@pytest.mark.asyncio
async def test_collect_runs_by_approval_skips_checker_when_status_resolved() -> None:
"""Approved/rejected shell calls must not invoke needs_approval_checker.

Mirrors #3229 for non-function tools: when the approval status is already
True or False, a user-supplied checker (which may have side effects, hit
the network, or raise) must be short-circuited.
"""
shell_tool = ShellTool(executor=lambda _req: "ok", needs_approval=True)
approved_call = make_shell_call("approved-shell")
rejected_call = make_shell_call("rejected-shell")
agent = Agent(name="agent")
context_wrapper = make_context_wrapper()
context_wrapper.approve_tool(
ToolApprovalItem(
agent=agent,
raw_item=cast(dict[str, Any], approved_call),
tool_name=shell_tool.name,
)
)
context_wrapper.reject_tool(
ToolApprovalItem(
agent=agent,
raw_item=cast(dict[str, Any], rejected_call),
tool_name=shell_tool.name,
)
)

runs = [
ToolRunShellCall(tool_call=approved_call, shell_tool=shell_tool),
ToolRunShellCall(tool_call=rejected_call, shell_tool=shell_tool),
]
checker_calls: list[str] = []

async def _needs_approval(run: ToolRunShellCall) -> bool:
checker_calls.append(run.tool_call["call_id"])
raise AssertionError("checker must not run for resolved approvals")

async def _build_rejection(run: ToolRunShellCall, call_id: str) -> RunItem:
return ToolCallOutputItem(
output="rejected",
raw_item={"type": "function_call_output", "call_id": call_id, "output": "rejected"},
agent=agent,
)

approved, rejections = await _collect_runs_by_approval(
runs,
call_id_extractor=lambda run: run.tool_call["call_id"],
tool_name_resolver=lambda run: run.shell_tool.name,
rejection_builder=_build_rejection,
context_wrapper=context_wrapper,
approval_items_by_call_id={},
agent=agent,
pending_interruption_adder=lambda _item: None,
needs_approval_checker=_needs_approval,
output_exists_checker=lambda _call_id: False,
)

assert checker_calls == []
assert approved == [runs[0]]
assert len(rejections) == 1


@pytest.mark.asyncio
async def test_resume_rebuilds_function_runs_from_object_approvals() -> None:
"""Rebuild should handle ResponseFunctionToolCall approval items."""
Expand Down