Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
472 changes: 78 additions & 394 deletions haystack/components/agents/agent.py

Large diffs are not rendered by default.

26 changes: 7 additions & 19 deletions haystack/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any

from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot, ToolBreakpoint
from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot


class PipelineError(Exception):
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
pipeline_snapshot: PipelineSnapshot | None = None,
pipeline_snapshot_file_path: str | None = None,
*,
break_point: AgentBreakpoint | Breakpoint | ToolBreakpoint | None = None,
break_point: Breakpoint | None = None,
) -> None:
super().__init__(message)
self.component = component
Expand All @@ -127,7 +127,7 @@ def __init__(
raise ValueError("Either pipeline_snapshot or break_point must be provided.")

@classmethod
def from_triggered_breakpoint(cls, break_point: Breakpoint | ToolBreakpoint) -> "BreakpointException":
def from_triggered_breakpoint(cls, break_point: Breakpoint) -> "BreakpointException":
"""
Create a BreakpointException from a triggered breakpoint.
"""
Expand All @@ -137,37 +137,25 @@ def from_triggered_breakpoint(cls, break_point: Breakpoint | ToolBreakpoint) ->
@property
def inputs(self) -> dict[str, Any] | None:
"""
Returns the inputs of the pipeline or agent at the breakpoint.

If an AgentBreakpoint caused this exception, returns the inputs of the agent's internal components.
Otherwise, returns the current inputs of the pipeline.
Returns the current inputs of the pipeline at the breakpoint.
"""
if not self.pipeline_snapshot:
return None

if self.pipeline_snapshot.agent_snapshot:
return self.pipeline_snapshot.agent_snapshot.component_inputs
return self.pipeline_snapshot.pipeline_state.inputs

@property
def results(self) -> dict[str, Any] | None:
"""
Returns the results of the pipeline or agent at the breakpoint.

If an AgentBreakpoint caused this exception, returns the current results of the agent.
Otherwise, returns the current outputs of the pipeline.
Returns the current outputs of the pipeline at the breakpoint.
"""
if not self.pipeline_snapshot:
return None

if self.pipeline_snapshot.agent_snapshot:
return self.pipeline_snapshot.agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"]
return self.pipeline_snapshot.pipeline_state.pipeline_outputs

@property
def break_point(self) -> AgentBreakpoint | Breakpoint | ToolBreakpoint:
def break_point(self) -> Breakpoint:
"""
Returns the Breakpoint or AgentBreakpoint that caused this exception, if available.
Returns the Breakpoint that caused this exception.

If a specific break point was provided during initialization, it is returned.
Otherwise, if the pipeline snapshot contains a break point, that is returned.
Expand Down
244 changes: 12 additions & 232 deletions haystack/core/pipeline/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,15 @@
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import Any

from networkx import MultiDiGraph

from haystack import logging
from haystack.core.errors import PipelineInvalidPipelineSnapshotError
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import (
AgentBreakpoint,
AgentSnapshot,
Breakpoint,
PipelineSnapshot,
PipelineState,
ToolBreakpoint,
)
from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot, PipelineState
from haystack.utils.base_serialization import _serialize_value_with_schema
from haystack.utils.misc import _get_output_dir

if TYPE_CHECKING:
from haystack.components.agents.agent import _ExecutionContext
from haystack.tools import ToolsType

logger = logging.getLogger(__name__)

Expand All @@ -54,34 +41,17 @@ def _is_snapshot_save_enabled() -> bool:
return value in ("true", "1")


def _validate_break_point_against_pipeline(break_point: Breakpoint | AgentBreakpoint, graph: MultiDiGraph) -> None:
def _validate_break_point_against_pipeline(break_point: Breakpoint, graph: MultiDiGraph) -> None:
"""
Validates the breakpoints passed to the pipeline.

Makes sure the breakpoint contains a valid components registered in the pipeline.

:param break_point: a breakpoint to validate, can be Breakpoint or AgentBreakpoint
:param break_point: a breakpoint to validate
"""

# all Breakpoints must refer to a valid component in the pipeline
if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes:
if break_point.component_name not in graph.nodes:
raise ValueError(f"break_point {break_point} is not a registered component in the pipeline")

if isinstance(break_point, AgentBreakpoint):
breakpoint_agent_component = graph.nodes.get(break_point.agent_name)
if not breakpoint_agent_component:
raise ValueError(f"break_point {break_point} is not a registered Agent component in the pipeline")

if isinstance(break_point.break_point, ToolBreakpoint):
instance = breakpoint_agent_component["instance"]
for tool in instance.tools:
if break_point.break_point.tool_name == tool.name:
break
else:
raise ValueError(
f"break_point {break_point.break_point} is not a registered tool in the Agent component"
)


def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnapshot, graph: MultiDiGraph) -> None:
"""
Expand Down Expand Up @@ -121,11 +91,7 @@ def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnap
f"are not part of the current pipeline."
)

if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
component_name = pipeline_snapshot.break_point.agent_name
else:
component_name = pipeline_snapshot.break_point.component_name

component_name = pipeline_snapshot.break_point.component_name
visit_count = pipeline_snapshot.pipeline_state.component_visits[component_name]

logger.info(
Expand Down Expand Up @@ -216,30 +182,18 @@ def _save_pipeline_snapshot(
return None

break_point = pipeline_snapshot.break_point
snapshot_file_path = (
break_point.break_point.snapshot_file_path
if isinstance(break_point, AgentBreakpoint)
else break_point.snapshot_file_path
)
snapshot_file_path = break_point.snapshot_file_path

if snapshot_file_path is None:
return None

dt = pipeline_snapshot.timestamp or datetime.now()
snapshot_dir = Path(snapshot_file_path)

# Generate filename
# We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
if isinstance(break_point, AgentBreakpoint):
agent_name = break_point.agent_name
component_name = break_point.break_point.component_name
else:
component_name = break_point.component_name
agent_name = None

component_name = break_point.component_name
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json"
file_name = f"{component_name}_{visit_nr}_{timestamp}.json"
full_path = snapshot_dir / file_name

try:
Expand All @@ -262,7 +216,7 @@ def _create_pipeline_snapshot(
*,
inputs: dict[str, Any],
component_inputs: dict[str, Any],
break_point: AgentBreakpoint | Breakpoint,
break_point: Breakpoint,
component_visits: dict[str, int],
original_input_data: dict[str, Any],
ordered_component_names: list[str],
Expand All @@ -274,7 +228,7 @@ def _create_pipeline_snapshot(

:param inputs: The current pipeline snapshot inputs.
:param component_inputs: The inputs to the component that triggered the breakpoint.
:param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
:param break_point: The breakpoint that triggered the snapshot.
:param component_visits: The visit count of the component that triggered the breakpoint.
:param original_input_data: The original input data.
:param ordered_component_names: The ordered component names.
Expand All @@ -283,10 +237,7 @@ def _create_pipeline_snapshot(
:returns:
A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
"""
if isinstance(break_point, AgentBreakpoint):
component_name = break_point.agent_name
else:
component_name = break_point.component_name
component_name = break_point.component_name

transformed_original_input_data = _transform_json_structure(original_input_data)
transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs})
Expand Down Expand Up @@ -343,31 +294,6 @@ def _transform_json_structure(data: dict[str, Any] | list[Any] | Any) -> Any:
return data


def _create_agent_snapshot(
*, component_visits: dict[str, int], agent_breakpoint: AgentBreakpoint, component_inputs: dict[str, Any]
) -> AgentSnapshot:
"""
Create a snapshot of the agent's state.

:param component_visits: The visit counts for the agent's components.
:param agent_breakpoint: AgentBreakpoint object containing breakpoints
:return: An AgentSnapshot containing the agent's state and component visits.
"""
serialized_chat_generator = _serialize_agent_component_inputs(
component_name="chat_generator", component_inputs=component_inputs["chat_generator"]
)
serialized_tool_invoker = _serialize_agent_component_inputs(
component_name="tool_invoker", component_inputs=component_inputs["tool_invoker"]
)

return AgentSnapshot(
component_inputs={"chat_generator": serialized_chat_generator, "tool_invoker": serialized_tool_invoker},
component_visits=component_visits,
break_point=agent_breakpoint,
timestamp=datetime.now(),
)


def _serialize_with_field_fallback(payload: Any, *, description: str) -> dict[str, Any]:
"""
Serialize a payload and, on failure, retry field-by-field to preserve resumable fields.
Expand Down Expand Up @@ -417,149 +343,3 @@ def _serialize_with_field_fallback(payload: Any, *, description: str) -> dict[st
"serialization_schema": {"type": "object", "properties": serialized_properties},
"serialized_data": serialized_data,
}


def _serialize_agent_component_inputs(component_name: str, component_inputs: dict[str, Any]) -> dict[str, Any]:
"""
Serialize agent component inputs while preserving resumable fields whenever possible.

Thin wrapper around :func:`_serialize_with_field_fallback` that supplies an agent-specific label
for the warning messages.

:param component_name: Name of the agent sub-component (e.g. ``chat_generator`` or ``tool_invoker``).
:param component_inputs: Runtime inputs for that sub-component.
:returns: A serialized payload that is always a structurally valid ``{"serialization_schema",
"serialized_data"}`` pair. When every field fails to serialize, an empty-but-valid object
payload is returned so that ``_deserialize_value_with_schema`` can still load it (for example
when resuming from a ``ToolBreakpoint`` where the sub-component's inputs are not strictly required).
"""
return _serialize_with_field_fallback(component_inputs, description=f"the agent's {component_name} inputs")


def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: "ToolsType") -> None:
"""
Validates the AgentBreakpoint passed to the agent.

Validates that the tool name in ToolBreakpoints correspond to a tool available in the agent.

:param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components.
:param tools: A list of Tool and/or Toolset objects, or a Toolset that the agent can use.
:raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools.
"""
from haystack.tools.utils import flatten_tools_or_toolsets # avoid circular import

available_tool_names = {tool.name for tool in flatten_tools_or_toolsets(tools)}
tool_breakpoint = agent_breakpoint.break_point
# Assert added for mypy to pass, but this is already checked before this function is called
assert isinstance(tool_breakpoint, ToolBreakpoint)
if tool_breakpoint.tool_name and tool_breakpoint.tool_name not in available_tool_names:
raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools")


def _create_pipeline_snapshot_from_chat_generator(
*, execution_context: "_ExecutionContext", agent_name: str | None = None, break_point: AgentBreakpoint | None = None
) -> PipelineSnapshot:
"""
Create a pipeline snapshot when a chat generator breakpoint is raised or an exception during execution occurs.

:param execution_context: The current execution context of the agent.
:param agent_name: The name of the agent component if present in a pipeline.
:param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
A scenario where a new breakpoint is created is when an exception occurs during chat generation and we want to
capture the state at that point.
:returns:
A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
"""
if break_point is None:
agent_breakpoint = AgentBreakpoint(
agent_name=agent_name or "agent",
break_point=Breakpoint(
component_name="chat_generator",
visit_count=execution_context.component_visits["chat_generator"],
snapshot_file_path=_get_output_dir("pipeline_snapshot"),
),
)
else:
agent_breakpoint = break_point

agent_snapshot = _create_agent_snapshot(
component_visits=execution_context.component_visits,
agent_breakpoint=agent_breakpoint,
component_inputs={
"chat_generator": {
"messages": execution_context.state.data["messages"],
**execution_context.chat_generator_inputs,
},
"tool_invoker": {"messages": [], "state": execution_context.state, **execution_context.tool_invoker_inputs},
},
)

return PipelineSnapshot._from_agent_snapshot(agent_snapshot=agent_snapshot)


def _create_pipeline_snapshot_from_tool_invoker(
*,
execution_context: "_ExecutionContext",
tool_name: str | None = None,
agent_name: str | None = None,
break_point: AgentBreakpoint | None = None,
) -> PipelineSnapshot:
"""
Create a pipeline snapshot when a tool invoker breakpoint is raised or an exception during execution occurs.

:param execution_context: The current execution context of the agent.
:param tool_name: The name of the tool that triggered the breakpoint, if available.
:param agent_name: The name of the agent component if present in a pipeline.
:param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
A scenario where a new breakpoint is created is when an exception occurs during tool execution and we want to
capture the state at that point.
:returns:
A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
"""
if break_point is None:
agent_breakpoint = AgentBreakpoint(
agent_name=agent_name or "agent",
break_point=ToolBreakpoint(
component_name="tool_invoker",
visit_count=execution_context.component_visits["tool_invoker"],
tool_name=tool_name,
snapshot_file_path=_get_output_dir("pipeline_snapshot"),
),
)
else:
agent_breakpoint = break_point

messages = execution_context.state.data["messages"]
agent_snapshot = _create_agent_snapshot(
component_visits=execution_context.component_visits,
agent_breakpoint=agent_breakpoint,
component_inputs={
"chat_generator": {"messages": messages[:-1], **execution_context.chat_generator_inputs},
"tool_invoker": {
"messages": messages[-1:], # tool invoker consumes last msg from the chat_generator, contains tool call
"state": execution_context.state,
**execution_context.tool_invoker_inputs,
},
},
)

# Create an empty pipeline snapshot
return PipelineSnapshot._from_agent_snapshot(agent_snapshot=agent_snapshot)


def _should_trigger_tool_invoker_breakpoint(break_point: ToolBreakpoint, llm_messages: list[ChatMessage]) -> bool:
"""
Determine if a tool invoker breakpoint should be triggered based on the provided ToolBreakpoint and LLM messages.

:param break_point: The ToolBreakpoint to check against.
:param llm_messages: A list of ChatMessage objects representing the LLM messages.
:returns:
True if the breakpoint should be triggered, False otherwise.
"""
# Check if we should break for this specific tool or all tools
if break_point.tool_name is None:
# Break for any tool call
return any(msg.tool_call for msg in llm_messages)

# Break only for the specific tool
return any(tc.tool_name == break_point.tool_name for msg in llm_messages for tc in msg.tool_calls or [])
Loading
Loading