Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 57 additions & 101 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, ChatRole, StreamingCallbackT, select_streaming_callback
from haystack.human_in_the_loop.strategies import (
_deserialize_confirmation_strategies,
_process_confirmation_strategies,
_process_confirmation_strategies_async,
)
Expand All @@ -41,6 +42,8 @@

# Regex to detect the Jinja2 chat template syntax
_JINJA2_CHAT_TEMPLATE_RE = re.compile(r"\{%\s*message\s")
# Regex to extract the role from a Jinja2 message block, e.g. {% message role="user" %}
_JINJA2_MESSAGE_ROLE_RE = re.compile(r'\{%\s*message\s+role\s*=\s*["\'](\w+)["\']')


def _get_run_method_params(instance: "Agent") -> set[str]:
Expand All @@ -49,6 +52,23 @@ def _get_run_method_params(instance: "Agent") -> set[str]:
return {name for name, p in sig.parameters.items() if p.kind != inspect.Parameter.VAR_KEYWORD}


def _validate_prompt_message_blocks(user_prompt: str | None, system_prompt: str | None) -> None:
    """Validate that user_prompt and system_prompt define exactly one message block with the correct role."""

    def _check_roles(prompt_name: str, expected_role: str, found_roles: list[str]) -> None:
        # A prompt may contain at most one message block, and its declared role must match.
        if len(found_roles) > 1:
            raise ValueError(f"{prompt_name} must define exactly one message block, found {len(found_roles)}.")
        if found_roles and found_roles[0] != expected_role:
            raise ValueError(
                f"{prompt_name} message block must have role '{expected_role}', found role '{found_roles[0]}'."
            )

    if user_prompt is not None:
        _check_roles("user_prompt", "user", _JINJA2_MESSAGE_ROLE_RE.findall(user_prompt))

    # system_prompt is only validated when it actually uses the Jinja2 message-block syntax;
    # a plain-string system prompt carries no role information to check.
    if system_prompt is not None and _JINJA2_CHAT_TEMPLATE_RE.search(system_prompt):
        _check_roles("system_prompt", "system", _JINJA2_MESSAGE_ROLE_RE.findall(system_prompt))


@dataclass(kw_only=True)
class _ExecutionContext:
"""
Expand Down Expand Up @@ -230,37 +250,29 @@ def __init__(
:raises ValueError: If the exit_conditions are not valid.
:raises ValueError: If any `user_prompt` variable overlaps with the `state_schema` or `run` method parameters.
"""
# Check if chat_generator supports tools parameter
chat_generator_run_method = inspect.signature(chat_generator.run)
self._chat_generator_supports_tools: bool = "tools" in chat_generator_run_method.parameters
# --- Validation ---
self._chat_generator_supports_tools: bool = "tools" in inspect.signature(chat_generator.run).parameters
if tools and not self._chat_generator_supports_tools:
raise TypeError(
f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
"The Agent component requires a chat generator that supports tools when tools are provided."
)

valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
if exit_conditions is None:
exit_conditions = ["text"]
valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
if not all(condition in valid_exits for condition in exit_conditions):
raise ValueError(
f"Invalid exit conditions provided: {exit_conditions}. "
f"Valid exit conditions must be a subset of {valid_exits}. "
"Ensure that each exit condition corresponds to either 'text' or a valid tool name."
)

# Validate state schema if provided
if state_schema is not None:
_validate_schema(state_schema)
self._state_schema = state_schema or {}

# Initialize state schema
# shallow copy is sufficient: we only add a top-level "messages" key, never mutate nested values
resolved_state_schema = dict(self._state_schema)
if resolved_state_schema.get("messages") is None:
resolved_state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
self.state_schema = resolved_state_schema
_validate_prompt_message_blocks(user_prompt, system_prompt)

# --- Attributes ---
self.chat_generator = chat_generator
self.tools = tools or []
self.system_prompt = system_prompt
Expand All @@ -270,48 +282,51 @@ def __init__(
self.max_agent_steps = max_agent_steps
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
self.streaming_callback = streaming_callback
self.tool_invoker_kwargs = tool_invoker_kwargs
self._confirmation_strategies = confirmation_strategies or {}
self._is_warmed_up = False

# Set input and output types for the component based on the State schema
# --- State schema ---
# shallow copy is sufficient: we only add a top-level "messages" key, never mutate nested values
self._state_schema = state_schema or {}
self.state_schema = dict(self._state_schema)
if self.state_schema.get("messages") is None:
self.state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}

# --- Component I/O ---
self._run_method_params = _get_run_method_params(self)
output_types = {"last_message": ChatMessage}
for param, config in self.state_schema.items():
output_types[param] = config["type"]
# Skip setting input types for parameters that are already in the run method
if param in self._run_method_params:
continue
component.set_input_type(self, name=param, type=config["type"], default=None)
if param not in self._run_method_params:
component.set_input_type(self, name=param, type=config["type"], default=None)
component.set_output_types(self, **output_types)

# required_variables is initially set to [] and populated later by _register_prompt_variables
# --- Prompt builders ---
# required_variables starts empty and is populated by _register_prompt_variables once
# builder.variables are known
self._user_chat_prompt_builder = (
ChatPromptBuilder(template=user_prompt, required_variables=[]) if user_prompt is not None else None
)
# Only create a system prompt builder when the prompt uses Jinja2 message syntax
self._system_chat_prompt_builder: ChatPromptBuilder | None = None
if system_prompt is not None and _JINJA2_CHAT_TEMPLATE_RE.search(system_prompt):
self._system_chat_prompt_builder = ChatPromptBuilder(template=system_prompt, required_variables=[])

self._register_prompt_variables()

self.tool_invoker_kwargs = tool_invoker_kwargs
# --- Tool invoker ---
self._tool_invoker = None
if self.tools:
resolved_tool_invoker_kwargs = {
"tools": self.tools,
"raise_on_failure": self.raise_on_tool_invocation_failure,
**(tool_invoker_kwargs or {}),
}
self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
self._tool_invoker = ToolInvoker(
tools=self.tools,
raise_on_failure=self.raise_on_tool_invocation_failure,
**(self.tool_invoker_kwargs or {}),
)
elif type(self).__name__ == "Agent":
logger.warning(
"No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
"responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator."
)

self._confirmation_strategies = confirmation_strategies or {}

self._is_warmed_up = False

def _register_prompt_variables(self) -> None:
"""
Collect variables from both Chat Prompt Builders and register Agent inputs.
Expand Down Expand Up @@ -426,16 +441,9 @@ def from_dict(cls, data: dict[str, Any]) -> "Agent":
deserialize_tools_or_toolset_inplace(init_params, key="tools")

if init_params.get("confirmation_strategies") is not None:
restored: dict[str | tuple[str, ...], Any] = {}
for raw_key in init_params["confirmation_strategies"].keys():
deserialize_component_inplace(init_params["confirmation_strategies"], key=raw_key)
strategy = init_params["confirmation_strategies"][raw_key]
if isinstance(raw_key, list):
key = tuple(raw_key)
else:
key = raw_key
restored[key] = strategy
init_params["confirmation_strategies"] = restored
init_params["confirmation_strategies"] = _deserialize_confirmation_strategies(
init_params["confirmation_strategies"]
)

return default_from_dict(cls, data)

Expand Down Expand Up @@ -464,8 +472,6 @@ def _initialize_fresh_execution(
streaming_callback: StreamingCallbackT | None,
requires_async: bool,
*,
system_prompt: str | None = None,
user_prompt: str | None = None,
generation_kwargs: dict[str, Any] | None = None,
tools: ToolsType | list[str] | None = None,
confirmation_strategy_context: dict[str, Any] | None = None,
Expand All @@ -477,9 +483,6 @@ def _initialize_fresh_execution(
:param messages: List of ChatMessage objects to start the agent with.
:param streaming_callback: Optional callback for streaming responses.
:param requires_async: Whether the agent run requires asynchronous execution.
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
:param user_prompt: User prompt for the agent. If provided, it overrides the default user prompt and is
appended to the messages provided at runtime.
:param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
override the parameters passed during component initialization.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
Expand All @@ -488,57 +491,24 @@ def _initialize_fresh_execution(
to confirmation strategies.
:param kwargs: Additional data to pass to the State used by the Agent.
"""
user_prompt = user_prompt or self.user_prompt
system_prompt = system_prompt or self.system_prompt
messages = messages or []

if user_prompt is not None:
if self._user_chat_prompt_builder is None:
raise ValueError(
"user_prompt is provided but the ChatPromptBuilder is not initialized. "
"Please make sure a user_prompt is provided at initialization time."
)

# Only forward the prompt kwargs to the prompt builder
prompt_kwargs = {var: kwargs[var] for var in self._user_chat_prompt_builder.variables if var in kwargs}
user_messages = self._user_chat_prompt_builder.run(template=user_prompt, **prompt_kwargs)["prompt"]
if len(user_messages) != 1:
raise ValueError(
f"user_prompt must render to exactly one user message. Got {len(user_messages)} messages."
)
if not user_messages[0].is_from(ChatRole.USER):
raise ValueError(
f"user_prompt must render to a user message. Got a message with role {user_messages[0].role}."
)
if self.user_prompt is not None:
prompt_kwargs = {var: kwargs[var] for var in self._user_chat_prompt_builder.variables if var in kwargs} # type: ignore[union-attr]
user_messages = self._user_chat_prompt_builder.run(template=self.user_prompt, **prompt_kwargs)["prompt"] # type: ignore[union-attr]
messages = messages + user_messages

if system_prompt is not None:
if _JINJA2_CHAT_TEMPLATE_RE.search(system_prompt):
if self._system_chat_prompt_builder is None:
raise ValueError(
"system_prompt contains Jinja2 template syntax but no system prompt builder is initialized. "
"Please make sure a system_prompt with Jinja2 template syntax is provided at initialization "
"time."
)

if self.system_prompt is not None:
if self._system_chat_prompt_builder is not None:
prompt_kwargs = {
var: kwargs[var] for var in self._system_chat_prompt_builder.variables if var in kwargs
}
system_messages = self._system_chat_prompt_builder.run(template=system_prompt, **prompt_kwargs)[
system_messages = self._system_chat_prompt_builder.run(template=self.system_prompt, **prompt_kwargs)[
"prompt"
]
if len(system_messages) != 1:
raise ValueError(
f"system_prompt must render to exactly one system message. Got {len(system_messages)} messages."
)
if not system_messages[0].is_from(ChatRole.SYSTEM):
raise ValueError(
"system_prompt must render to a system message. "
f"Got a message with role {system_messages[0].role}."
)
messages = system_messages + messages
else:
messages = [ChatMessage.from_system(system_prompt)] + messages
messages = [ChatMessage.from_system(self.system_prompt)] + messages

if all(m.is_from(ChatRole.SYSTEM) for m in messages):
logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")
Expand Down Expand Up @@ -619,8 +589,6 @@ def run( # noqa: PLR0915
streaming_callback: StreamingCallbackT | None = None,
*,
generation_kwargs: dict[str, Any] | None = None,
system_prompt: str | None = None,
user_prompt: str | None = None,
tools: ToolsType | list[str] | None = None,
confirmation_strategy_context: dict[str, Any] | None = None,
**kwargs: Any,
Expand All @@ -633,9 +601,6 @@ def run( # noqa: PLR0915
The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
:param user_prompt: User prompt for the agent. If provided, it overrides the default user prompt and is
appended to the messages provided at runtime.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
When passing tool names, tools are selected from the Agent's originally configured tools.
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
Expand All @@ -658,8 +623,6 @@ def run( # noqa: PLR0915
messages=messages,
streaming_callback=streaming_callback,
requires_async=False,
system_prompt=system_prompt,
user_prompt=user_prompt,
generation_kwargs=generation_kwargs,
tools=tools,
confirmation_strategy_context=confirmation_strategy_context,
Expand Down Expand Up @@ -690,8 +653,6 @@ async def run_async( # noqa: PLR0915
streaming_callback: StreamingCallbackT | None = None,
*,
generation_kwargs: dict[str, Any] | None = None,
system_prompt: str | None = None,
user_prompt: str | None = None,
tools: ToolsType | list[str] | None = None,
confirmation_strategy_context: dict[str, Any] | None = None,
**kwargs: Any,
Expand All @@ -708,9 +669,6 @@ async def run_async( # noqa: PLR0915
LLM. The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
:param user_prompt: User prompt for the agent. If provided, it overrides the default user prompt and is
appended to the messages provided at runtime.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
Expand All @@ -734,8 +692,6 @@ async def run_async( # noqa: PLR0915
messages=messages,
streaming_callback=streaming_callback,
requires_async=True,
system_prompt=system_prompt,
user_prompt=user_prompt,
tools=tools,
generation_kwargs=generation_kwargs,
confirmation_strategy_context=confirmation_strategy_context,
Expand Down
16 changes: 16 additions & 0 deletions haystack/human_in_the_loop/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,19 @@ def _update_chat_history(
insertion_point = max(last_user_idx, last_tool_idx)

return chat_history[: insertion_point + 1] + rejection_messages + tool_call_and_explanation_messages


def _deserialize_confirmation_strategies(data: dict[str, Any]) -> dict[str | tuple[str, ...], ConfirmationStrategy]:
    """
    Rebuild a confirmation strategies mapping from its serialized representation.

    Each strategy component is deserialized in place; any key that arrives as a list is
    converted back to a tuple, since JSON serializes tuple keys as lists.

    :param data: Raw dictionary of serialized confirmation strategies, keyed by tool name(s).
    :returns: Deserialized confirmation strategies with proper key types.
    """
    # Snapshot the keys first: deserialize_component_inplace mutates the mapping's values.
    for key in list(data):
        deserialize_component_inplace(data, key=key)

    restored: dict[str | tuple[str, ...], ConfirmationStrategy] = {}
    for key, strategy in data.items():
        restored[tuple(key) if isinstance(key, list) else key] = strategy
    return restored
Loading
Loading