Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 56 additions & 27 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .._async import run_async
from ..agent import Agent
from ..agent.base import AgentBase
from ..agent.state import AgentState
from ..hooks.events import (
AfterMultiAgentInvocationEvent,
Expand Down Expand Up @@ -65,16 +66,19 @@ class SwarmNode:
"""Represents a node (e.g. Agent) in the swarm."""

node_id: str
executor: Agent
executor: AgentBase
swarm: Optional["Swarm"] = None
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)

def __post_init__(self) -> None:
    """Capture the executor's initial messages and state, when it exposes them.

    The executor is typed as ``AgentBase``, which is not guaranteed to carry
    ``messages`` or ``state``. Each attribute is therefore probed before it is
    read; when one is missing, the dataclass default for the corresponding
    ``_initial_*`` field (an empty list / a fresh ``AgentState``) is kept.
    """
    executor = self.executor

    # Deep copy so later mutation of the live conversation cannot alter the snapshot.
    if hasattr(executor, "messages"):
        self._initial_messages = copy.deepcopy(executor.messages)

    # Only snapshot state objects that expose ``get()``; anything else keeps the default.
    state = getattr(executor, "state", None)
    if hasattr(state, "get"):
        self._initial_state = AgentState(state.get())

def __hash__(self) -> int:
"""Return hash for SwarmNode based on node_id."""
Expand All @@ -99,15 +103,20 @@ def reset_executor_state(self) -> None:

If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
"""
if self.swarm and self.swarm._interrupt_state.activated:
# Handle interrupt state restoration (Agent-specific)
if self.swarm and self.swarm._interrupt_state.activated and isinstance(self.executor, Agent):
context = self.swarm._interrupt_state.context[self.node_id]
self.executor.messages = context["messages"]
self.executor.state = AgentState(context["state"])
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
return

self.executor.messages = copy.deepcopy(self._initial_messages)
self.executor.state = AgentState(self._initial_state.get())
# Reset to initial state (works with any AgentBase that has these attributes)
if hasattr(self.executor, "messages"):
self.executor.messages = copy.deepcopy(self._initial_messages)

if hasattr(self.executor, "state"):
self.executor.state = AgentState(self._initial_state.get())


@dataclass
Expand Down Expand Up @@ -232,9 +241,9 @@ class Swarm(MultiAgentBase):

def __init__(
self,
nodes: list[Agent],
nodes: list[AgentBase],
*,
entry_point: Agent | None = None,
entry_point: AgentBase | None = None,
max_handoffs: int = 20,
max_iterations: int = 20,
execution_timeout: float = 900.0,
Expand Down Expand Up @@ -458,19 +467,20 @@ async def _stream_with_timeout(
except asyncio.TimeoutError as err:
raise Exception(timeout_message) from err

def _setup_swarm(self, nodes: list[Agent]) -> None:
def _setup_swarm(self, nodes: list[AgentBase]) -> None:
"""Initialize swarm configuration."""
# Validate nodes before setup
self._validate_swarm(nodes)

# Validate agents have names and create SwarmNode objects
for i, node in enumerate(nodes):
if not node.name:
# Only access name if it exists (AgentBase protocol doesn't guarantee it)
node_name = getattr(node, "name", None)
if not node_name:
node_id = f"node_{i}"
node.name = node_id
logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id)

node_id = str(node.name)
logger.debug("node_id=<%s> | agent has no name, using generated id", node_id)
else:
node_id = str(node_name)

# Ensure node IDs are unique
if node_id in self.nodes:
Expand All @@ -480,7 +490,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:

# Validate entry point if specified
if self.entry_point is not None:
entry_point_node_id = str(self.entry_point.name)
entry_point_node_id = str(getattr(self.entry_point, "name", None))
if (
entry_point_node_id not in self.nodes
or self.nodes[entry_point_node_id].executor is not self.entry_point
Expand All @@ -500,7 +510,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
first_node = next(iter(self.nodes.keys()))
logger.debug("entry_point=<%s> | using first node as entry point", first_node)

def _validate_swarm(self, nodes: list[Agent]) -> None:
def _validate_swarm(self, nodes: list[AgentBase]) -> None:
"""Validate swarm structure and nodes."""
# Check for duplicate object instances
seen_instances = set()
Expand All @@ -509,18 +519,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
seen_instances.add(id(node))

# Check for session persistence
if node._session_manager is not None:
# Check for session persistence (only Agent has _session_manager attribute)
if isinstance(node, Agent) and node._session_manager is not None:
raise ValueError("Session persistence is not supported for Swarm agents yet.")

def _inject_swarm_tools(self) -> None:
"""Add swarm coordination tools to each agent."""
"""Add swarm coordination tools to each agent.

Note: Only Agent instances can receive swarm tools. AgentBase implementations
without tool_registry will not have handoff capabilities.
"""
# Create tool functions with proper closures
swarm_tools = [
self._create_handoff_tool(),
]

injected_count = 0
for node in self.nodes.values():
# Only Agent (not generic AgentBase) has tool_registry attribute
if not isinstance(node.executor, Agent):
logger.debug(
"node_id=<%s> | skipping tool injection for non-Agent node",
node.node_id,
)
continue

# Check for existing tools with conflicting names
existing_tools = node.executor.tool_registry.registry
conflicting_tools = []
Expand All @@ -536,11 +559,13 @@ def _inject_swarm_tools(self) -> None:

# Use the agent's tool registry to process and register the tools
node.executor.tool_registry.process_tools(swarm_tools)
injected_count += 1

logger.debug(
"tool_count=<%d>, node_count=<%d> | injected coordination tools into agents",
"tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools",
len(swarm_tools),
len(self.nodes),
injected_count,
)

def _create_handoff_tool(self) -> Callable[..., Any]:
Expand Down Expand Up @@ -692,12 +717,14 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
logger.debug("node=<%s> | node interrupted", node.node_id)
self.state.completion_status = Status.INTERRUPTED

self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}
# Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
if isinstance(node.executor, Agent):
self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()
Expand Down Expand Up @@ -1037,5 +1064,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:

def _initial_node(self) -> SwarmNode:
    """Return the node the swarm should start executing from.

    Resolution order when an entry point is configured:

    1. Look the entry point up by its ``name`` attribute, if it has one and
       that name is a registered node id.
    2. Otherwise look for the node whose executor *is* the entry point object.
       This covers executors without a ``name``, which are registered under a
       generated ``node_<i>`` id.

    If no entry point is configured, or neither lookup matches, the first
    registered node is returned.

    Raises:
        StopIteration: If the swarm has no nodes at all.
    """
    entry_point = self.entry_point

    # ``is not None`` mirrors the entry-point check in ``_setup_swarm``.
    if entry_point is not None:
        # ``AgentBase`` does not guarantee a ``name`` attribute, so never access it directly.
        entry_point_name = getattr(entry_point, "name", None)
        if entry_point_name and str(entry_point_name) in self.nodes:
            return self.nodes[str(entry_point_name)]

        # Fall back to identity so unnamed entry points are still honoured.
        for candidate in self.nodes.values():
            if candidate.executor is entry_point:
                return candidate

    return next(iter(self.nodes.values()))  # First SwarmNode
Loading