Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions openviking/session/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,22 @@ async def _merge_into_existing(
viking_fs,
ctx: RequestContext,
) -> bool:
"""Merge candidate content into an existing memory file."""
"""Merge candidate content into an existing memory file.

Uses optimistic concurrency control (OCC): reads the file's modTime
before the expensive LLM merge call, then verifies it hasn't changed
before writing. If a concurrent modification is detected, the merge
is aborted and the caller treats it as skipped — the next commit
cycle will re-dedup and re-attempt.
"""
try:
stat_info = await viking_fs.stat(target_memory.uri, ctx=ctx)
read_mod_time = (
str(stat_info.get("modTime", stat_info.get("mtime", "")))
if isinstance(stat_info, dict)
else ""
)

existing_content = await viking_fs.read_file(target_memory.uri, ctx=ctx)
payload = await self.extractor._merge_memory_bundle(
existing_abstract=target_memory.abstract,
Expand All @@ -269,7 +283,19 @@ async def _merge_into_existing(
if not payload:
return False

await viking_fs.write_file(target_memory.uri, payload.content, ctx=ctx)
if read_mod_time:
written = await viking_fs.cas_write_file(
target_memory.uri, payload.content, read_mod_time, ctx=ctx
)
if not written:
logger.warning(
"OCC: concurrent modification on %s, skipping merge",
target_memory.uri,
)
return False
else:
await viking_fs.write_file(target_memory.uri, payload.content, ctx=ctx)

target_memory.abstract = payload.abstract
target_memory.meta = {**(target_memory.meta or {}), "overview": payload.overview}
logger.info(
Expand All @@ -285,8 +311,8 @@ async def _merge_into_existing(
# Clean up vector record for the missing file so it's not retried
try:
await self.vikingdb.delete_uris(ctx, [target_memory.uri])
except Exception:
pass
except Exception as e:
logger.warning("Failed to clean up vector record for %s: %s", target_memory.uri, e)
return False
except Exception as e:
logger.error(f"Failed to merge memory {target_memory.uri}: {e}")
Expand All @@ -295,7 +321,22 @@ async def _merge_into_existing(
async def _delete_existing_memory(
self, memory: Context, viking_fs, ctx: RequestContext
) -> bool:
"""Hard delete an existing memory file and clean up its vector record."""
"""Hard delete an existing memory file and clean up its vector record.

Enforces a dedup score guardrail: refuses to delete memories whose
_dedup_score is below 0.5, preventing LLM-hallucinated DELETE decisions
on low-similarity matches.
"""
dedup_score = (memory.meta or {}).get("_dedup_score", 0)
if isinstance(dedup_score, (int, float)) and dedup_score < 0.5:
logger.warning(
"Refusing to delete memory %s with low dedup_score=%.4f "
"(floor=0.50). LLM may have hallucinated the DELETE decision.",
memory.uri,
float(dedup_score),
)
return False

try:
await viking_fs.rm(memory.uri, recursive=False, ctx=ctx)
except Exception as e:
Expand Down
65 changes: 59 additions & 6 deletions openviking/session/memory_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,13 +526,26 @@ async def _append_to_profile(
viking_fs,
ctx: RequestContext,
) -> Optional[MergedMemoryPayload]:
"""Update user profile - always merge with existing content."""
"""Update user profile - always merge with existing content.

Uses optimistic concurrency control (OCC) when merging into an
existing profile to prevent lost updates from concurrent sessions.
"""
uri = f"{canonical_user_root(ctx)}/memories/profile.md"
existing = ""
read_mod_time = ""
try:
stat_info = await viking_fs.stat(uri, ctx=ctx)
read_mod_time = (
str(stat_info.get("modTime", stat_info.get("mtime", "")))
if isinstance(stat_info, dict)
else ""
)
existing = await viking_fs.read_file(uri, ctx=ctx) or ""
except Exception:
pass
except FileNotFoundError:
logger.debug("Profile %s does not exist yet, will create", uri)
except Exception as e:
logger.warning("Failed to stat/read profile %s: %s", uri, e)

if not existing.strip():
await viking_fs.write_file(uri=uri, content=candidate.content, ctx=ctx)
Expand All @@ -557,7 +570,17 @@ async def _append_to_profile(
if not payload:
logger.warning("Profile merge bundle failed; keeping existing profile unchanged")
return None
await viking_fs.write_file(uri=uri, content=payload.content, ctx=ctx)
if read_mod_time:
written = await viking_fs.cas_write_file(
uri, payload.content, read_mod_time, ctx=ctx
)
if not written:
logger.warning(
"OCC: concurrent modification on profile %s, skipping write", uri
)
return None
else:
await viking_fs.write_file(uri=uri, content=payload.content, ctx=ctx)
logger.info(f"Merged profile info to {uri}")
return payload

Expand Down Expand Up @@ -644,7 +667,14 @@ async def _merge_tool_memory(
return None

existing = ""
read_mod_time = ""
try:
stat_info = await viking_fs.stat(uri, ctx=ctx)
read_mod_time = (
str(stat_info.get("modTime", stat_info.get("mtime", "")))
if isinstance(stat_info, dict)
else ""
)
existing = await viking_fs.read_file(uri, ctx=ctx) or ""
except NotFoundError:
existing = ""
Expand Down Expand Up @@ -749,7 +779,15 @@ async def _merge_tool_memory(
merged_content = self._generate_tool_memory_content(
tool_name, merged_stats, merged_guidelines, fields=merged_fields
)
await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
if read_mod_time:
written = await viking_fs.cas_write_file(uri, merged_content, read_mod_time, ctx=ctx)
if not written:
logger.warning(
"OCC: concurrent modification on tool memory %s, skipping write", uri
)
return None
else:
await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override)

def _compute_statistics_derived(self, stats: dict) -> dict:
Expand Down Expand Up @@ -1198,7 +1236,14 @@ async def _merge_skill_memory(
return None

existing = ""
read_mod_time = ""
try:
stat_info = await viking_fs.stat(uri, ctx=ctx)
read_mod_time = (
str(stat_info.get("modTime", stat_info.get("mtime", "")))
if isinstance(stat_info, dict)
else ""
)
existing = await viking_fs.read_file(uri, ctx=ctx) or ""
except NotFoundError:
existing = ""
Expand Down Expand Up @@ -1320,7 +1365,15 @@ async def _merge_skill_memory(
merged_content = self._generate_skill_memory_content(
skill_name, merged_stats, merged_guidelines, fields=merged_fields
)
await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
if read_mod_time:
written = await viking_fs.cas_write_file(uri, merged_content, read_mod_time, ctx=ctx)
if not written:
logger.warning(
"OCC: concurrent modification on skill memory %s, skipping write", uri
)
return None
else:
await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override)

def _compute_skill_statistics_derived(self, stats: dict) -> dict:
Expand Down
46 changes: 46 additions & 0 deletions openviking/storage/viking_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,52 @@ async def write_file(
content = await self._encrypt_content(content, ctx=ctx)
await self._run_in_threadpool(self.agfs.write, path, content)

async def cas_write_file(
self,
uri: str,
content: Union[str, bytes],
expected_mod_time: str,
ctx: Optional[RequestContext] = None,
) -> bool:
"""Compare-and-swap write: only succeeds if the file's modTime matches.

Returns True if the write succeeded, False if a concurrent modification
was detected (modTime changed since the read).
"""
self._ensure_access(uri, ctx)
path = self._uri_to_path(uri, ctx=ctx)

try:
stat = await self._run_in_threadpool(self.agfs.stat, path)
except FileNotFoundError:
logger.debug("OCC stat failed for %s: file not found", uri)
return False
except Exception as e:
logger.warning("OCC stat failed for %s: %s", uri, e)
return False

current_mod_time = ""
if isinstance(stat, dict):
current_mod_time = str(stat.get("modTime", stat.get("mtime", "")))

if current_mod_time != expected_mod_time:
logger.warning(
"OCC conflict on %s: expected modTime=%s, current=%s",
uri,
expected_mod_time,
current_mod_time,
)
return False

await self._ensure_parent_dirs(path)

if isinstance(content, str):
content = content.encode("utf-8")

content = await self._encrypt_content(content, ctx=ctx)
await self._run_in_threadpool(self.agfs.write, path, content)
return True

async def read_file(
self,
uri: str,
Expand Down
145 changes: 145 additions & 0 deletions tests/session/test_memory_occ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

from openviking.core.context import Context
from openviking.server.identity import RequestContext, Role
from openviking.session.compressor import SessionCompressor
from openviking.session.memory_extractor import (
CandidateMemory,
MemoryCategory,
MergedMemoryPayload,
)
from openviking_cli.session.user_id import UserIdentifier


def _make_user() -> UserIdentifier:
return UserIdentifier("acc1", "test_user", "test_agent")


def _make_ctx() -> RequestContext:
return RequestContext(user=_make_user(), role=Role.USER)


def _make_candidate() -> CandidateMemory:
return CandidateMemory(
category=MemoryCategory.PREFERENCES,
abstract="test abstract",
overview="test overview",
content="test content",
source_session="s1",
user=_make_user(),
language="en",
)


def _make_memory(meta=None) -> Context:
m = Context(
uri="viking://user/test_user/memories/preferences/existing.md",
parent_uri="viking://user/test_user/memories/preferences",
is_leaf=True,
abstract="existing",
context_type="memory",
category="preferences",
)
if meta:
m.meta = meta
return m


def _make_compressor() -> SessionCompressor:
with patch("openviking.session.memory_deduplicator.get_openviking_config") as cfg:
cfg.return_value.embedding.get_embedder.return_value = None
return SessionCompressor(vikingdb=MagicMock())


def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)


class TestDeleteGuardrail:
def test_low_dedup_score_blocks_delete(self):
memory = _make_memory(meta={"_dedup_score": 0.3})
fs = MagicMock()
fs.rm = AsyncMock()
assert _run(_make_compressor()._delete_existing_memory(memory, fs, _make_ctx())) is False
fs.rm.assert_not_called()

def test_high_dedup_score_allows_delete(self):
memory = _make_memory(meta={"_dedup_score": 0.8})
fs = MagicMock()
fs.rm = AsyncMock()
assert _run(_make_compressor()._delete_existing_memory(memory, fs, _make_ctx())) is True
fs.rm.assert_called_once()


class TestMergeOCC:
def test_merge_aborted_on_concurrent_modification(self):
compressor = _make_compressor()
fs = MagicMock()
fs.stat = AsyncMock(return_value={"modTime": "T1"})
fs.read_file = AsyncMock(return_value="old")
fs.cas_write_file = AsyncMock(return_value=False)

with patch.object(
compressor.extractor,
"_merge_memory_bundle",
new=AsyncMock(
return_value=MergedMemoryPayload(
abstract="a", overview="o", content="c", reason="r"
)
),
):
with patch.object(compressor, "_index_memory", new=AsyncMock()):
assert (
_run(
compressor._merge_into_existing(
_make_candidate(), _make_memory(), fs, _make_ctx()
)
)
is False
)

def test_merge_succeeds_when_no_concurrent_write(self):
compressor = _make_compressor()
fs = MagicMock()
fs.stat = AsyncMock(return_value={"modTime": "T1"})
fs.read_file = AsyncMock(return_value="old")
fs.cas_write_file = AsyncMock(return_value=True)

with patch.object(
compressor.extractor,
"_merge_memory_bundle",
new=AsyncMock(
return_value=MergedMemoryPayload(
abstract="a", overview="o", content="c", reason="r"
)
),
):
with patch.object(compressor, "_index_memory", new=AsyncMock()):
assert (
_run(
compressor._merge_into_existing(
_make_candidate(), _make_memory(), fs, _make_ctx()
)
)
is True
)
fs.cas_write_file.assert_called_once()


class TestCasWriteFile:
def test_cas_write_blocks_on_modtime_mismatch(self):
from openviking.storage.viking_fs import VikingFS

fs = VikingFS.__new__(VikingFS)
fs._agfs = MagicMock()
fs._agfs.stat = MagicMock(return_value={"modTime": "T2"})
fs._agfs.write = MagicMock()
fs._ensure_access = MagicMock()
fs._run_in_threadpool = AsyncMock(side_effect=lambda fn, *a: fn(*a))

assert _run(fs.cas_write_file("viking://test/f.md", "c", "T1", ctx=_make_ctx())) is False
fs._agfs.write.assert_not_called()
Loading