Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def run(
Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
specific tool invocation.
:param confirmation_strategy_context:
Optional dictionary for passing request-scoped resources. Useful in web/server environments
to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
that strategies can use for non-blocking user interaction.
Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
interface compatibility.

:returns:
A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
Expand Down Expand Up @@ -140,7 +139,8 @@ async def run_async(
:param tool_call_id:
Optional unique identifier for the tool call.
:param confirmation_strategy_context:
Optional dictionary for passing request-scoped resources.
Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
interface compatibility.

:returns:
A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
Expand Down Expand Up @@ -263,7 +263,8 @@ async def run_async(
:param tool_call_id:
Optional unique identifier for the tool call.
:param confirmation_strategy_context:
Optional dictionary for passing request-scoped resources.
Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
interface compatibility.

:raises HITLBreakpointException:
Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def update_after_confirmation(
confirmation_result: ConfirmationUIResult,
) -> None:
"""Update the policy based on the confirmation UI result."""
pass
return

def to_dict(self) -> dict[str, Any]:
"""Serialize the policy to a dictionary."""
Expand All @@ -64,11 +64,12 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfirmationPolicy":
class ConfirmationStrategy(Protocol):
def run(
self,
*,
tool_name: str,
tool_description: str,
tool_params: dict[str, Any],
tool_call_id: str | None = None,
**kwargs: dict[str, Any] | None,
confirmation_strategy_context: dict[str, Any] | None = None,
) -> ToolExecutionDecision:
"""
Run the confirmation strategy for a given tool and its parameters.
Expand All @@ -78,9 +79,8 @@ def run(
:param tool_params: The parameters to be passed to the tool.
:param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
the decision with a specific tool invocation.
:param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
environments.
:param confirmation_strategy_context: Optional context dictionary for passing request-scoped resources
(e.g., WebSocket connections, async queues) in web/server environments.

:returns:
The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
Expand All @@ -89,11 +89,12 @@ def run(

async def run_async(
self,
*,
tool_name: str,
tool_description: str,
tool_params: dict[str, Any],
tool_call_id: str | None = None,
**kwargs: dict[str, Any] | None,
confirmation_strategy_context: dict[str, Any] | None = None,
) -> ToolExecutionDecision:
"""
Async version of run. Run the confirmation strategy for a given tool and its parameters.
Expand All @@ -105,9 +106,8 @@ async def run_async(
:param tool_params: The parameters to be passed to the tool.
:param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
the decision with a specific tool invocation.
:param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
environments.
:param confirmation_strategy_context: Optional context dictionary for passing request-scoped resources
(e.g., WebSocket connections, async queues) in web/server environments.

:returns:
The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
Expand Down
13 changes: 13 additions & 0 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def run_agent(
snapshot = None
if snapshot_file_path:
snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
assert snapshot.agent_snapshot is not None

# Add any new tool execution decisions to the snapshot
if tool_execution_decisions:
Expand All @@ -152,6 +153,7 @@ def run_pipeline_with_agent(
snapshot = None
if snapshot_file_path:
snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
assert snapshot.agent_snapshot is not None

# Add any new tool execution decisions to the snapshot
if tool_execution_decisions:
Expand All @@ -175,6 +177,7 @@ async def run_agent_async(
snapshot = None
if snapshot_file_path:
snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
assert snapshot.agent_snapshot is not None

# Add any new tool execution decisions to the snapshot
if tool_execution_decisions:
Expand Down Expand Up @@ -284,6 +287,7 @@ def test_from_dict(self, tools, confirmation_strategies, monkeypatch):
assert deserialized_agent.to_dict() == agent.to_dict()
assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
assert len(deserialized_agent.tools) == 1
assert isinstance(deserialized_agent.tools[0], Tool)
assert deserialized_agent.tools[0].name == "addition_tool"
assert isinstance(deserialized_agent._tool_invoker, type(agent._tool_invoker))
assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"], BlockingConfirmationStrategy)
Expand Down Expand Up @@ -316,6 +320,7 @@ def test_get_tool_calls_and_descriptions_from_snapshot_no_mutation_of_snapshot(s
original_snapshot = copy.deepcopy(loaded_snapshot)

# Extract tool calls and descriptions
assert loaded_snapshot.agent_snapshot is not None
_ = get_tool_calls_and_descriptions_from_snapshot(
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
)
Expand All @@ -341,6 +346,7 @@ def test_run_blocking_confirmation_strategy_modify(self, tools):
result = agent.run([ChatMessage.from_user("What is 2+2?")])

assert isinstance(result["last_message"], ChatMessage)
assert result["last_message"].text is not None
assert "5" in result["last_message"].text

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
Expand All @@ -362,6 +368,7 @@ def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
while result is None:
# Load the latest snapshot from disk and prep data for front-end
loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
assert loaded_snapshot.agent_snapshot is not None
serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
)
Expand All @@ -379,6 +386,7 @@ def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
# Step 3: Final result
last_message = result["last_message"]
assert isinstance(last_message, ChatMessage)
assert last_message.text is not None
assert "5" in last_message.text

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
Expand All @@ -402,6 +410,7 @@ def test_run_in_pipeline_breakpoint_confirmation_strategy_modify(self, tools, tm
while result is None:
# Load the latest snapshot from disk and prep data for front-end
loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
assert loaded_snapshot.agent_snapshot is not None
serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
)
Expand All @@ -419,6 +428,7 @@ def test_run_in_pipeline_breakpoint_confirmation_strategy_modify(self, tools, tm
# Step 3: Final result
last_message = result["agent"]["last_message"]
assert isinstance(last_message, ChatMessage)
assert last_message.text is not None
assert "5" in last_message.text

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
Expand All @@ -440,6 +450,7 @@ async def test_run_async_blocking_confirmation_strategy_modify(self, tools):
result = await agent.run_async([ChatMessage.from_user("What is 2+2?")])

assert isinstance(result["last_message"], ChatMessage)
assert result["last_message"].text is not None
assert "5" in result["last_message"].text

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
Expand All @@ -462,6 +473,7 @@ async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tm
while result is None:
# Load the latest snapshot from disk and prep data for front-end
loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
assert loaded_snapshot.agent_snapshot is not None
serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
)
Expand All @@ -479,6 +491,7 @@ async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tm
# Step 3: Final result
last_message = result["last_message"]
assert isinstance(last_message, ChatMessage)
assert last_message.text is not None
assert "5" in last_message.text


Expand Down
Loading