Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 86 additions & 43 deletions src/bedrock_agentcore/memory/integrations/strands/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __init__(
session = boto_session or boto3.Session(region_name=region_name)
self.has_existing_agent = False

# Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp)
self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = []
# Batching support - stores pre-processed messages: (session_id, agent_id, messages, is_blob, timestamp)
self._message_buffer: list[tuple[str, Optional[str], list[tuple[str, str]], bool, datetime]] = []
self._message_lock = threading.Lock()

# Agent state buffering - stores all agent state updates: (session_id, agent)
Expand Down Expand Up @@ -482,6 +482,11 @@ def create_message(

is_blob = self.converter.exceeds_conversational_limit(messages[0])

# Build agent_id metadata for multi-agent message tagging
agent_metadata = None
if agent_id:
agent_metadata = {AGENT_ID_KEY: {"stringValue": agent_id}}

# Parse the original timestamp and use it as desired timestamp
original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)
Expand All @@ -490,7 +495,7 @@ def create_message(
# Buffer the pre-processed message
should_flush = False
with self._message_lock:
self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp))
self._message_buffer.append((session_id, agent_id, messages, is_blob, monotonic_timestamp))
should_flush = len(self._message_buffer) >= self.config.batch_size

# Flush only messages outside the lock to prevent deadlock
Expand All @@ -508,17 +513,19 @@ def create_message(
session_id=session_id,
messages=messages,
event_timestamp=monotonic_timestamp,
metadata=agent_metadata,
)
else:
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
payload=[
{"blob": json.dumps(messages[0])},
],
eventTimestamp=monotonic_timestamp,
)
create_event_kwargs: dict[str, Any] = {
"memoryId": self.config.memory_id,
"actorId": self.config.actor_id,
"sessionId": session_id,
"payload": [{"blob": json.dumps(messages[0])}],
"eventTimestamp": monotonic_timestamp,
}
if agent_metadata:
create_event_kwargs["metadata"] = agent_metadata
event = self.memory_client.gmdp_client.create_event(**create_event_kwargs)
logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id)
return event
except Exception as e:
Expand Down Expand Up @@ -598,12 +605,46 @@ def list_messages(
try:
max_results = (limit + offset) if limit else MAX_FETCH_ALL_RESULTS

events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=max_results,
)
# Try filtering by agent_id first for multi-agent support
if agent_id:
agent_id_filter = [
EventMetadataFilter.build_expression(
left_operand=LeftExpression.build(AGENT_ID_KEY),
operator=OperatorType.EQUALS_TO,
right_operand=RightExpression.build(agent_id),
)
]
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=max_results,
event_metadata=agent_id_filter,
)

# Backward compatibility: if filtered query returns empty, retry without
# the agent_id filter. This handles sessions created before multi-agent
# metadata was added to message events.
if not events:
logger.debug(
"No events found with agent_id filter for agent %s, "
"falling back to unfiltered query for backward compatibility.",
agent_id,
)
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=max_results,
)
else:
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=max_results,
)

messages = self.converter.events_to_messages(events)
if self.config.filter_restored_tool_context:
messages = self._filter_restored_tool_context(messages)
Expand Down Expand Up @@ -756,11 +797,8 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None:
@override
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
if self.has_existing_agent:
logger.warning(
"An Agent already exists in session %s. We currently support one agent per session.", self.session_id
)
else:
self.has_existing_agent = True
logger.info("Multiple agents registered in session %s.", self.session_id)
self.has_existing_agent = True
RepositorySessionManager.initialize(self, agent, **kwargs)

# endregion RepositorySessionManager overrides
Expand Down Expand Up @@ -789,44 +827,49 @@ def _flush_messages_only(self) -> list[dict[str, Any]]:
if not messages_to_send:
return []

# Group all messages by session_id, combining conversational and blob messages
# Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}}
session_groups: dict[str, dict[str, Any]] = {}
# Group all messages by (session_id, agent_id), combining conversational and blob messages
# Structure: {(session_id, agent_id): {"payload": [...], "timestamp": latest_timestamp}}
session_groups: dict[tuple[str, Optional[str]], dict[str, Any]] = {}

for session_id, messages, is_blob, monotonic_timestamp in messages_to_send:
if session_id not in session_groups:
session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp}
for session_id, agent_id, messages, is_blob, monotonic_timestamp in messages_to_send:
group_key = (session_id, agent_id)
if group_key not in session_groups:
session_groups[group_key] = {"payload": [], "timestamp": monotonic_timestamp}

if is_blob:
# Add blob messages to payload
for msg in messages:
session_groups[session_id]["payload"].append({"blob": json.dumps(msg)})
session_groups[group_key]["payload"].append({"blob": json.dumps(msg)})
else:
# Add conversational messages to payload
for text, role in messages:
session_groups[session_id]["payload"].append(
session_groups[group_key]["payload"].append(
{"conversational": {"content": {"text": text}, "role": role.upper()}}
)

# Use the latest timestamp for the combined event
if monotonic_timestamp > session_groups[session_id]["timestamp"]:
session_groups[session_id]["timestamp"] = monotonic_timestamp
if monotonic_timestamp > session_groups[group_key]["timestamp"]:
session_groups[group_key]["timestamp"] = monotonic_timestamp

results = []
try:
# Send one create_event per session_id with all messages (conversational + blob)
for session_id, group in session_groups.items():
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
payload=group["payload"],
eventTimestamp=group["timestamp"],
)
# Send one create_event per (session_id, agent_id) with all messages
for (session_id, agent_id), group in session_groups.items():
create_event_kwargs: dict[str, Any] = {
"memoryId": self.config.memory_id,
"actorId": self.config.actor_id,
"sessionId": session_id,
"payload": group["payload"],
"eventTimestamp": group["timestamp"],
}
if agent_id:
create_event_kwargs["metadata"] = {AGENT_ID_KEY: {"stringValue": agent_id}}
event = self.memory_client.gmdp_client.create_event(**create_event_kwargs)
results.append(event)
logger.debug(
"Flushed batched event for session %s with %d messages: %s",
"Flushed batched event for session %s agent %s with %d messages: %s",
session_id,
agent_id,
len(group["payload"]),
event.get("eventId"),
)
Expand Down
Loading
Loading