Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions compaction_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Example: Automatic context compaction with SummarizationCompactionTool

This example shows how to add automatic chat history summarization to a Haystack
Agent using the new SummarizationCompactionTool and the condition-triggered tool
mechanism.

The agent has a simple calculator tool and is given a pre-existing conversation
history with alternating user/assistant messages. Once the non-system message count
exceeds `max_messages`, the compaction tool automatically summarizes the oldest
messages into a single concise entry — without the agent having to decide to do it.

Run with:
OPENAI_API_KEY=<your-key> python compaction_example.py
"""

from typing import Annotated

from haystack.components.agents import Agent
from haystack.components.agents.compaction import SummarizationCompactionTool
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import tool


@tool
def add(a: Annotated[float, "First number"], b: Annotated[float, "Second number"]) -> float:
    """Add two numbers."""
    # Compute the sum in a named local before returning it.
    total = a + b
    return total


# A dedicated generator for the compaction summarizer (can be the same model or a cheaper/faster one — here we reuse
# the same model for simplicity).
summarizer_generator = OpenAIChatGenerator(model="gpt-4o-mini")

# The compaction tool fires automatically before each LLM call once the non-system message count exceeds max_messages.
# It summarizes the oldest half of the history and keeps the most recent messages verbatim.
compaction = SummarizationCompactionTool(
    chat_generator=summarizer_generator,
    max_messages=3,  # low threshold so it triggers quickly in this demo
)

# The agent is given both the calculator tool and the condition-triggered compaction tool.
agent = Agent(
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
    tools=[add, compaction],
    system_prompt="You are a helpful assistant that can perform calculations.",
)

# Pre-existing conversation history with alternating user/assistant messages.
# This simulates a session that has already been running for a while and will
# push the message count past the compaction threshold on the next agent call.
messages = [
    ChatMessage.from_user("What is quantum mechanics?"),
    ChatMessage.from_assistant(
        "Quantum mechanics is a fundamental theory in physics that describes the behavior of matter and energy at "
        "the smallest scales."
    ),
    ChatMessage.from_user("Can you explain it in simple terms?"),
    ChatMessage.from_assistant(
        "Sure! Quantum mechanics is like a set of rules that govern how tiny particles, like electrons and photons,"
        " behave. It tells us that these particles can exist in multiple states at once (superposition) and can be "
        "connected in ways that seem to defy classical physics (entanglement). It's a fascinating and complex field "
        "that has led to many technological advancements, like semiconductors and quantum computing."
    ),
    ChatMessage.from_user("Thanks! Now, can you do some math for me? What's 3 + 4?"),
]

print(f"History length before agent call: {len(messages)} messages")

# NOTE: this makes live OpenAI API calls and requires OPENAI_API_KEY to be set (see module docstring).
result = agent.run(messages=messages)

print(f"History length after agent call: {len(result['messages'])} messages")
print(f"\nAgent reply: {result['last_message'].text}")
3 changes: 2 additions & 1 deletion haystack/components/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from lazy_imports import LazyImporter

_import_structure = {"agent": ["Agent"], "state": ["State"]}
_import_structure = {"agent": ["Agent"], "compaction": ["SummarizationCompactionTool"], "state": ["State"]}

if TYPE_CHECKING:
from .agent import Agent as Agent
from .compaction import SummarizationCompactionTool as SummarizationCompactionTool
from .state import State as State

else:
Expand Down
79 changes: 77 additions & 2 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Literal, cast

from haystack import Pipeline, component, logging, tracing
Expand Down Expand Up @@ -98,6 +98,7 @@ class _ExecutionContext:
skip_chat_generator: bool = False
confirmation_strategy_context: dict[str, Any] | None = None
tool_execution_decisions: list[ToolExecutionDecision] | None = None
pending_condition_tools: dict[str, int] = field(default_factory=dict)


@component
Expand Down Expand Up @@ -281,7 +282,10 @@ def __init__(
"The Agent component requires a chat generator that supports tools when tools are provided."
)

valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
all_tools_flat = flatten_tools_or_toolsets(tools)
self._condition_tools: list[Tool] = [t for t in all_tools_flat if t.condition is not None]

valid_exits = ["text"] + [tool.name for tool in all_tools_flat]
if exit_conditions is None:
exit_conditions = ["text"]
if not all(condition in valid_exits for condition in exit_conditions):
Expand All @@ -301,6 +305,10 @@ def __init__(
resolved_state_schema = dict(self._state_schema)
if resolved_state_schema.get("messages") is None:
resolved_state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
if resolved_state_schema.get("tool_call_counts") is None:
resolved_state_schema["tool_call_counts"] = {"type": dict, "handler": replace_values}
if resolved_state_schema.get("step") is None:
resolved_state_schema["step"] = {"type": int, "handler": replace_values}
self.state_schema = resolved_state_schema

self.chat_generator = chat_generator
Expand Down Expand Up @@ -826,12 +834,14 @@ def run( # noqa: PLR0915
span.set_content_tag("haystack.agent.input", agent_inputs)

while exe_context.counter < self.max_agent_steps:
exe_context.state.set("step", exe_context.counter)
# We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
if exe_context.skip_chat_generator:
llm_messages = exe_context.state.get("messages", [])[-1:]
# Set to False so the next iteration will call the chat generator
exe_context.skip_chat_generator = False
else:
self._process_condition_tools(exe_context.state, exe_context)
try:
result = Pipeline._run_component(
component_name="chat_generator",
Expand Down Expand Up @@ -950,6 +960,14 @@ def run( # noqa: PLR0915

tool_messages = tool_invoker_result["tool_messages"]
exe_context.state = tool_invoker_result["state"]

# Update tool_call_counts and clear pending re-prompts for successfully called tools
for msg in tool_messages:
if msg.tool_call_result and not msg.tool_call_result.error:
called_name = msg.tool_call_result.origin.tool_name
self._increment_tool_call_count(exe_context.state, called_name)
exe_context.pending_condition_tools.pop(called_name, None)

exe_context.state.set("messages", tool_messages)

# Check if any LLM message's tool call name matches an exit condition
Expand Down Expand Up @@ -1060,12 +1078,14 @@ async def run_async( # noqa: PLR0915
span.set_content_tag("haystack.agent.input", agent_inputs)

while exe_context.counter < self.max_agent_steps:
exe_context.state.set("step", exe_context.counter)
# We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
if exe_context.skip_chat_generator:
llm_messages = exe_context.state.get("messages", [])[-1:]
# Set to False so the next iteration will call the chat generator
exe_context.skip_chat_generator = False
else:
self._process_condition_tools(exe_context.state, exe_context)
try:
result = await AsyncPipeline._run_component_async(
component_name="chat_generator",
Expand Down Expand Up @@ -1184,6 +1204,14 @@ async def run_async( # noqa: PLR0915

tool_messages = tool_invoker_result["tool_messages"]
exe_context.state = tool_invoker_result["state"]

# Update tool_call_counts and clear pending re-prompts for successfully called tools
for msg in tool_messages:
if msg.tool_call_result and not msg.tool_call_result.error:
called_name = msg.tool_call_result.origin.tool_name
self._increment_tool_call_count(exe_context.state, called_name)
exe_context.pending_condition_tools.pop(called_name, None)

exe_context.state.set("messages", tool_messages)

# Check if any LLM message's tool call name matches an exit condition
Expand All @@ -1207,6 +1235,53 @@ async def run_async( # noqa: PLR0915
result["last_message"] = msgs[-1]
return result

def _process_condition_tools(self, state: State, exe_context: _ExecutionContext) -> None:
    """
    Evaluate condition-triggered tools against the current state and act on any that fire.

    Runs right before each LLM invocation. Two behaviors, chosen by the tool's parameter schema:
    - No LLM-facing parameters: the tool is invoked directly with state-injected arguments
      (e.g. compaction tools).
    - Has parameters: the LLM must supply arguments, so a system message is appended asking it
      to call the tool. The request is repeated on subsequent iterations, at most 3 times per
      condition trigger.

    :param state: The current agent state.
    :param exe_context: The current execution context (tracks re-prompt attempts).
    """

    def _request_tool_call(pending_tool: Tool, attempt: int) -> None:
        # Record the attempt number and nudge the LLM via a system message.
        exe_context.pending_condition_tools[pending_tool.name] = attempt
        state.set(
            "messages",
            [ChatMessage.from_system(f"You must call the '{pending_tool.name}' tool now. {pending_tool.description}")],
        )

    for cond_tool in self._condition_tools:
        prior_attempts = exe_context.pending_condition_tools.get(cond_tool.name, 0)

        if prior_attempts == 0:
            # No pending re-prompt: evaluate the condition fresh.
            if not cond_tool.condition(state):
                continue
            if cond_tool.parameters.get("properties"):
                # LLM must supply arguments: start the re-prompt cycle.
                _request_tool_call(cond_tool, 1)
            else:
                # No LLM-facing params (e.g. compaction): invoke directly with state args.
                resolved_args = ToolInvoker._inject_state_args(cond_tool, {}, state)
                tool_result = cond_tool.invoke(**resolved_args)
                ToolInvoker._merge_tool_outputs(cond_tool, tool_result, state)
                self._increment_tool_call_count(state, cond_tool.name)
        elif prior_attempts < 3:
            # Tool not yet called by the LLM; re-prompt while under the retry limit.
            _request_tool_call(cond_tool, prior_attempts + 1)

@staticmethod
def _increment_tool_call_count(state: State, tool_name: str) -> None:
    """Record one additional invocation of *tool_name* in the state's tool_call_counts dict."""
    tally = state.get("tool_call_counts") or {}
    tally[tool_name] = 1 + tally.get(tool_name, 0)
    state.set("tool_call_counts", tally)

def _check_exit_conditions(self, llm_messages: list[ChatMessage], tool_messages: list[ChatMessage]) -> bool:
"""
Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.
Expand Down
128 changes: 128 additions & 0 deletions haystack/components/agents/compaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Any

from haystack.components.agents.state.state_utils import replace_values
from haystack.core.serialization import component_from_dict, component_to_dict, generate_qualified_class_name
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.tools import Tool

if TYPE_CHECKING:
from haystack.components.agents.state import State
from haystack.components.generators.chat.types import ChatGenerator


# JSON-schema parameters for compaction tools: an object with no properties,
# i.e. the tool takes no model-supplied arguments.
_COMPACTION_PARAMETERS: dict[str, Any] = {"type": "object", "properties": {}}


class SummarizationCompactionTool(Tool):
    """
    A tool that summarizes the oldest messages once the history exceeds a length threshold.

    The condition fires before each LLM call. When the number of non-system messages exceeds
    `max_messages`, the older portion is summarized by `chat_generator` into a single user message.
    The most recent `max_messages // 2` non-system messages are kept verbatim so the LLM retains
    immediate context. The system message (if any) is always preserved at the front.

    Because this tool is condition-triggered it is never exposed to the LLM and cannot be called
    by the model directly.

    Usage example:
    ```python
    from haystack.components.agents import Agent
    from haystack.components.agents.compaction import SummarizationCompactionTool
    from haystack.components.generators.chat import OpenAIChatGenerator

    compaction = SummarizationCompactionTool(
        chat_generator=OpenAIChatGenerator(),
        max_messages=20,
    )
    agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[compaction, ...])
    ```
    """

    def __init__(
        self, chat_generator: "ChatGenerator", max_messages: int = 10, summarization_prompt: str | None = None
    ) -> None:
        """
        Initialize the SummarizationCompactionTool.

        :param chat_generator: A chat generator used to produce the summary. Called synchronously.
        :param max_messages: Number of non-system messages that triggers compaction. Must be at
            least 1. Defaults to 10.
        :param summarization_prompt: Prompt sent to the chat generator to request a summary. The
            placeholder ``{messages}`` is replaced with the formatted conversation excerpt to summarize.
            If not provided, a sensible default is used.
        :raises ValueError: If `max_messages` is smaller than 1.
        """
        if max_messages < 1:
            # A non-positive threshold would make the condition fire on every call and
            # summarize an empty (or nonsensical) excerpt.
            raise ValueError(f"max_messages must be at least 1, got {max_messages}.")

        self.chat_generator = chat_generator
        self.max_messages = max_messages
        # Keep the raw (possibly None) value so to_dict/from_dict round-trips faithfully.
        self.summarization_prompt = summarization_prompt

        _prompt = summarization_prompt or (
            "Summarize the following conversation concisely, preserving all important context, "
            "decisions, and information that would be needed to continue the conversation:\n\n{messages}"
        )

        def _condition(state: "State") -> bool:
            # Trigger once the non-system history is strictly longer than max_messages.
            messages: list[ChatMessage] = state.get("messages") or []
            non_system = [m for m in messages if not m.is_from(ChatRole.SYSTEM)]
            return len(non_system) > max_messages

        def _summarize(messages: list[ChatMessage]) -> dict[str, Any]:
            # System messages are always preserved at the front of the rebuilt history.
            system_msgs = [m for m in messages if m.is_from(ChatRole.SYSTEM)]
            non_system = [m for m in messages if not m.is_from(ChatRole.SYSTEM)]

            # Keep the newest half of the threshold verbatim; summarize everything older.
            keep_count = max_messages // 2
            to_summarize = non_system[:-keep_count] if keep_count else non_system
            to_keep = non_system[-keep_count:] if keep_count else []

            # Messages without text (e.g. pure tool-call messages) are omitted from
            # the excerpt sent to the summarizer.
            messages_text = "\n".join(f"{m.role.value}: {m.text}" for m in to_summarize if m.text)
            prompt = _prompt.format(messages=messages_text)

            result = chat_generator.run(messages=[ChatMessage.from_user(prompt)])
            summary_text = result["replies"][0].text or ""

            summary_msg = ChatMessage.from_user(f"[Summary of previous conversation: {summary_text}]")
            return {"messages": system_msgs + [summary_msg] + to_keep}

        super().__init__(
            name="summarization_compaction",
            description="Summarizes the oldest chat messages to reduce context length.",
            parameters=_COMPACTION_PARAMETERS,
            function=_summarize,
            inputs_from_state={"messages": "messages"},
            outputs_to_state={"messages": {"source": "messages", "handler": replace_values}},
            condition=_condition,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the tool to a dictionary.

        :returns: Dictionary with serialized data.
        """
        serialized: dict[str, Any] = {
            "chat_generator": component_to_dict(obj=self.chat_generator, name="chat_generator"),
            "max_messages": self.max_messages,
            "summarization_prompt": self.summarization_prompt,
        }
        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SummarizationCompactionTool":
        """
        Deserializes the tool from a dictionary.

        :param data: Dictionary to deserialize from.
        :returns: Deserialized tool.
        """
        from haystack.core.serialization import import_class_by_name

        # Work on a shallow copy so deserialization does not mutate the caller's dictionary
        # (the original replaced data["data"]["chat_generator"] in place).
        inner_data = dict(data["data"])
        generator_data = inner_data["chat_generator"]
        generator_class = import_class_by_name(generator_data["type"])
        inner_data["chat_generator"] = component_from_dict(
            cls=generator_class, data=generator_data, name="chat_generator"
        )
        return cls(**inner_data)
Loading
Loading