Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 78 additions & 62 deletions src/praisonai-agents/praisonaiagents/session/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_HAS_FCNTL = False
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from ..paths import get_sessions_dir

Expand Down Expand Up @@ -291,6 +291,63 @@ def _load_session(self, session_id: str) -> SessionData:
self._cache[session_id] = session

return session

def _load_session_from_disk(self, session_id: str, filepath: str) -> SessionData:
"""Load session JSON from disk (caller must hold FileLock)."""
if os.path.exists(filepath):
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return SessionData.from_dict(data)
except (json.JSONDecodeError, IOError):
pass
return SessionData(session_id=session_id)
Comment on lines +295 to +304
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _load_session_from_disk silently swallows json.JSONDecodeError and IOError without any log output, whereas the original _load_session logs a warning. When a session file is corrupt or unreadable, the error will be invisible, making it very hard to diagnose data-loss incidents in production.

Suggested change
def _load_session_from_disk(self, session_id: str, filepath: str) -> SessionData:
"""Load session JSON from disk (caller must hold FileLock)."""
if os.path.exists(filepath):
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return SessionData.from_dict(data)
except (json.JSONDecodeError, IOError):
pass
return SessionData(session_id=session_id)
def _load_session_from_disk(self, session_id: str, filepath: str) -> SessionData:
"""Load session JSON from disk (caller must hold FileLock)."""
if os.path.exists(filepath):
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return SessionData.from_dict(data)
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to load session {session_id}: {e}")
return SessionData(session_id=session_id)


def _modify_session_locked(
self,
session_id: str,
mutator: Callable[[SessionData], None],
*,
error_label: str = "modify session",
) -> bool:
"""Apply mutator after reloading from disk under FileLock."""
filepath = self._get_session_path(session_id)

with FileLock(filepath, self.lock_timeout):
session = self._load_session_from_disk(session_id, filepath)
mutator(session)
session.updated_at = datetime.now(timezone.utc).isoformat()

if len(session.messages) > self.max_messages:
session.messages = session.messages[-self.max_messages:]

try:
dir_path = os.path.dirname(filepath) or "."
os.makedirs(dir_path, exist_ok=True)
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=dir_path,
delete=False,
suffix=".tmp",
) as f:
json.dump(session.to_dict(), f, indent=2, ensure_ascii=False)
temp_path = f.name

os.replace(temp_path, filepath)

with self._lock:
self._cache[session_id] = session

return True
except (IOError, OSError) as e:
logger.error(f"Failed to {error_label} {session_id}: {e}")
try:
if "temp_path" in locals():
os.remove(temp_path)
except (IOError, OSError):
pass
return False

def _save_session(self, session: SessionData) -> bool:
"""Save session to disk with atomic write."""
Expand Down Expand Up @@ -455,81 +512,41 @@ def set_agent_info(
user_id: Optional[str] = None,
) -> bool:
"""Set agent info for a session."""
session = self._load_session(session_id)

with self._lock:

def _apply(session: SessionData) -> None:
if agent_name:
session.agent_name = agent_name
if user_id:
session.user_id = user_id
self._cache[session_id] = session

return self._save_session(session)

return self._modify_session_locked(
session_id, _apply, error_label="set agent info for session"
)

def clear_session(self, session_id: str) -> bool:
"""Clear all messages from a session."""
session = self._load_session(session_id)

with self._lock:
session.messages.clear()
self._cache[session_id] = session

return self._save_session(session)
return self._modify_session_locked(
session_id,
lambda session: session.messages.clear(),
error_label="clear session",
)

def update_session_metadata(self, session_id: str, **fields: Any) -> bool:
"""Merge run stats / metadata fields into a persisted session."""
if not fields:
return True

filepath = self._get_session_path(session_id)

with FileLock(filepath, self.lock_timeout):
if os.path.exists(filepath):
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
session = SessionData.from_dict(data)
except (json.JSONDecodeError, IOError):
session = SessionData(session_id=session_id)
else:
session = SessionData(session_id=session_id)

def _apply(session: SessionData) -> None:
for key, value in fields.items():
if value is None:
continue
session.metadata[key] = value
if key in ("agent_id", "agent_name", "user_id"):
setattr(session, key, value)

session.updated_at = datetime.now(timezone.utc).isoformat()

try:
dir_path = os.path.dirname(filepath) or "."
os.makedirs(dir_path, exist_ok=True)
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=dir_path,
delete=False,
suffix=".tmp",
) as f:
json.dump(session.to_dict(), f, indent=2, ensure_ascii=False)
temp_path = f.name

os.replace(temp_path, filepath)

with self._lock:
self._cache[session_id] = session

return True
except (IOError, OSError) as e:
logger.error(f"Failed to update session metadata {session_id}: {e}")
try:
if "temp_path" in locals():
os.remove(temp_path)
except (IOError, OSError):
pass
return False
return self._modify_session_locked(
session_id, _apply, error_label="update session metadata"
)

def delete_session(self, session_id: str) -> bool:
"""Delete a session completely."""
Expand Down Expand Up @@ -683,16 +700,15 @@ def set_gateway_info(
Returns:
True if saved successfully
"""
session = self._load_session(session_id)

with self._lock:
def _apply(session: SessionData) -> None:
if gateway_session_id:
session.gateway_session_id = gateway_session_id
if agent_id:
session.agent_id = agent_id
self._cache[session_id] = session

return self._save_session(session)

return self._modify_session_locked(
session_id, _apply, error_label="set gateway info for session"
)

def get_by_gateway_session(self, gateway_session_id: str) -> Optional[SessionData]:
"""Get session data linked to a gateway session.
Expand Down
58 changes: 58 additions & 0 deletions src/praisonai-agents/tests/unit/session/test_session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,64 @@ def test_update_session_metadata_preserves_messages(self, temp_store):
session = writer.get_session("session-1")
assert session.metadata.get("model") == "gpt-4o-mini"
assert session.metadata.get("total_tokens") == 42

def test_set_agent_info_preserves_messages(self, temp_store):
"""Agent info updates must not drop messages added by another store instance."""
with tempfile.TemporaryDirectory() as tmpdir:
writer = DefaultSessionStore(session_dir=tmpdir)
reader = DefaultSessionStore(session_dir=tmpdir)

writer.add_user_message("session-1", "first")
reader._load_session("session-1")
writer.add_user_message("session-1", "second")

assert reader.set_agent_info("session-1", agent_name="TestBot", user_id="u-1")

writer.invalidate_cache("session-1")
history = writer.get_chat_history("session-1")
assert len(history) == 2
assert history[1]["content"] == "second"

session = writer.get_session("session-1")
assert session.agent_name == "TestBot"
assert session.user_id == "u-1"

def test_clear_session_preserves_new_messages(self, temp_store):
"""Clear must reload from disk so concurrent adds are not lost."""
Comment on lines +330 to +331
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The test name and docstring say "preserves new messages," but the assertion checks len(history) == 0 — i.e., that all messages were cleared. The test is validating correct clear_session behaviour (reload-then-clear), not message preservation. The misleading name could cause a future reader to misinterpret the intent and write a wrong expectation.

Suggested change
def test_clear_session_preserves_new_messages(self, temp_store):
"""Clear must reload from disk so concurrent adds are not lost."""
def test_clear_session_reads_latest_disk_state(self, temp_store):
"""Clear must reload from disk so it operates on the latest persisted state."""

with tempfile.TemporaryDirectory() as tmpdir:
writer = DefaultSessionStore(session_dir=tmpdir)
reader = DefaultSessionStore(session_dir=tmpdir)

writer.add_user_message("session-1", "first")
reader._load_session("session-1")
writer.add_user_message("session-1", "second")

assert reader.clear_session("session-1")

writer.invalidate_cache("session-1")
history = writer.get_chat_history("session-1")
assert len(history) == 0

def test_set_gateway_info_preserves_messages(self, temp_store):
"""Gateway info updates must not drop messages added by another store instance."""
with tempfile.TemporaryDirectory() as tmpdir:
writer = DefaultSessionStore(session_dir=tmpdir)
reader = DefaultSessionStore(session_dir=tmpdir)

writer.add_user_message("session-1", "first")
reader._load_session("session-1")
writer.add_user_message("session-1", "second")

assert reader.set_gateway_info("session-1", gateway_session_id="gw-123", agent_id="agent-456")

writer.invalidate_cache("session-1")
history = writer.get_chat_history("session-1")
assert len(history) == 2
assert history[1]["content"] == "second"

session = writer.get_session("session-1")
assert session.gateway_session_id == "gw-123"
assert session.agent_id == "agent-456"

def test_concurrent_writes(self, temp_store):
"""Test concurrent writes to same session."""
Expand Down
Loading