Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def _init_structure_tables(self):

conn.commit()

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items( # type: ignore[override]
self, items: list[TResponseInputItem],
wrapper: Any = None,
) -> None:
"""Add items to the session.

Args:
Expand Down Expand Up @@ -156,9 +159,11 @@ def _add_items_sync():

await asyncio.to_thread(_add_items_sync)

async def get_items(
async def get_items( # type: ignore[override]
self,
limit: int | None = None,
wrapper: Any = None,
Comment on lines 164 to +165
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve Session positional order in AdvancedSQLite get_items

The new Session.get_items contract adds wrapper as the second optional argument, but this override keeps branch_id in that slot and suppresses the incompatibility with # type: ignore[override]. Any caller using the protocol positionally (e.g. session.get_items(limit, wrapper)) will pass the wrapper object as branch_id, leading to incorrect branch lookup and broken context propagation for this built-in session implementation.

Useful? React with 👍 / 👎.

*,
branch_id: str | None = None,
) -> list[TResponseInputItem]:
"""Get items from current or specified branch.
Expand Down
27 changes: 22 additions & 5 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, Any, cast

import aiosqlite

from ...items import TResponseInputItem
from ...memory import SessionABC
from ...memory.session_settings import SessionSettings

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class AsyncSQLiteSession(SessionABC):
"""Async SQLite-based implementation of session storage.
Expand Down Expand Up @@ -102,7 +105,11 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
conn = await self._get_connection()
yield conn

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -150,7 +157,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -186,7 +197,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await conn.commit()

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down Expand Up @@ -220,7 +234,10 @@ async def pop_item(self) -> TResponseInputItem | None:

return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
async with self._locked_connection() as conn:
await conn.execute(
Expand Down
25 changes: 20 additions & 5 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import json
import random
import time
from typing import Any, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Literal

try:
from dapr.aio.clients import DaprClient
Expand All @@ -42,6 +42,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper

# Type alias for consistency levels
ConsistencyLevel = Literal["eventual", "strong"]

Expand Down Expand Up @@ -232,7 +235,10 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) ->
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -271,7 +277,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -324,7 +333,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
options=self._get_state_options(),
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down Expand Up @@ -368,7 +380,10 @@ async def pop_item(self) -> TResponseInputItem | None:
except (json.JSONDecodeError, TypeError):
return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
async with self._lock:
# Delete messages and metadata keys
Expand Down
37 changes: 28 additions & 9 deletions src/agents/extensions/memory/encrypt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import base64
import json
from typing import Any, Literal, TypeGuard, cast
from typing import TYPE_CHECKING, Any, Literal, TypeGuard, cast

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
Expand All @@ -40,6 +40,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class EncryptedEnvelope(TypedDict):
"""TypedDict for encrypted message envelopes stored in the underlying session."""
Expand Down Expand Up @@ -170,27 +173,43 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu
except (InvalidToken, KeyError):
return None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit)
async def get_items(
self,
limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper)
valid_items: list[TResponseInputItem] = []
for enc in encrypted_items:
item = self._unwrap(enc)
if item is not None:
valid_items.append(item)
return valid_items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
await self.underlying_session.add_items(
cast(list[TResponseInputItem], wrapped), wrapper=wrapper
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
while True:
enc = await self.underlying_session.pop_item()
enc = await self.underlying_session.pop_item(wrapper=wrapper)
if not enc:
return None
item = self._unwrap(enc)
if item is not None:
return item

async def clear_session(self) -> None:
await self.underlying_session.clear_session()
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
await self.underlying_session.clear_session(wrapper=wrapper)
25 changes: 20 additions & 5 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import json
import threading
import weakref
from typing import Any
from typing import TYPE_CHECKING, Any

try:
from importlib.metadata import version as _get_version
Expand All @@ -57,6 +57,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper

# Identifies this library in the MongoDB handshake for server-side telemetry.
_DRIVER_INFO = DriverInfo(name="openai-agents", version=_VERSION)

Expand Down Expand Up @@ -241,7 +244,10 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -283,7 +289,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -319,7 +328,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await self._messages.insert_many(payload, ordered=True)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand All @@ -340,7 +352,10 @@ async def pop_item(self) -> TResponseInputItem | None:
except (json.JSONDecodeError, KeyError, TypeError):
return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
await self._ensure_indexes()
await self._messages.delete_many({"session_id": self.session_id})
Expand Down
36 changes: 30 additions & 6 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import asyncio
import json
import time
from typing import Any
from typing import TYPE_CHECKING, Any

try:
import redis.asyncio as redis
Expand All @@ -38,6 +38,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class RedisSession(SessionABC):
"""Redis implementation of :pyclass:`agents.memory.session.Session`."""
Expand Down Expand Up @@ -140,12 +143,16 @@ async def _set_ttl_if_configured(self, *keys: str) -> None:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
limit: Maximum number of items to retrieve. If None, uses session_settings.limit.
When specified, returns the latest N items in chronological order.
wrapper: Optional run context wrapper providing context and usage info.

Returns:
List of input items representing the conversation history
Expand Down Expand Up @@ -179,11 +186,15 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
items: List of input items to add to the history
wrapper: Optional run context wrapper providing context and usage info.
"""
if not items:
return
Expand Down Expand Up @@ -221,9 +232,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
self._session_key, self._messages_key, self._counter_key
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Args:
wrapper: Optional run context wrapper providing context and usage info.

Returns:
The most recent item if it exists, None if the session is empty
"""
Expand All @@ -245,8 +262,15 @@ async def pop_item(self) -> TResponseInputItem | None:
# Return None for corrupted messages (already removed)
return None

async def clear_session(self) -> None:
"""Clear all items for this session."""
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session.

Args:
wrapper: Optional run context wrapper providing context and usage info.
"""
async with self._lock:
# Delete all keys associated with this session
await self._redis.delete(
Expand Down
Loading