Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
961 changes: 961 additions & 0 deletions openviking/session/extraction_preprocessor.py

Large diffs are not rendered by default.

176 changes: 162 additions & 14 deletions openviking/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from uuid import uuid4

from openviking.core.namespace import canonical_session_uri
Expand All @@ -19,6 +19,7 @@
from openviking.server.config import ToolOutputExternalizationConfig
from openviking.server.identity import RequestContext, Role
from openviking.session.tool_result_store import ToolResultStore, make_preview, sha256_text
from openviking.session.wm_constants import WM_SEVEN_SECTIONS
from openviking.telemetry import get_current_telemetry, tracer
from openviking.telemetry.request_wait_tracker import get_request_wait_tracker
from openviking.utils.time_utils import get_current_timestamp
Expand All @@ -29,6 +30,7 @@

if TYPE_CHECKING:
from openviking.session.compressor import SessionCompressor
from openviking.session.extraction_preprocessor import PreprocessorOptions
from openviking.storage import VikingDBManager
from openviking.storage.viking_fs import VikingFS

Expand All @@ -55,16 +57,6 @@ def _wm_debug(msg: str) -> None:
# the server do section-level merge against the previous WM.
# =====================================================================

WM_SEVEN_SECTIONS: List[str] = [
"Session Title",
"Current State",
"Task & Goals",
"Key Facts & Decisions",
"Files & Context",
"Errors & Corrections",
"Open Issues",
]

_WM_SECTION_OP_SCHEMA: Dict[str, Any] = {
"oneOf": [
{
Expand Down Expand Up @@ -1922,6 +1914,141 @@ def _generate_archive_summary(
turn_count = len([m for m in messages if m.role == "user"])
return f"# Session Summary\n\n**Overview**: {turn_count} turns, {len(messages)} messages"

@staticmethod
def _build_wm_preprocessor_options(memory_config: Any) -> "PreprocessorOptions":
    """Translate WM v2 preprocessing settings into ``PreprocessorOptions``.

    Every option is looked up on ``memory_config`` with ``getattr``, so an
    attribute that is missing (including when ``memory_config`` is ``None``)
    resolves to the built-in default from the table below. The value is then
    coerced to the primitive type listed for that option.

    Args:
        memory_config: Object exposing ``wm_v2_preprocess_*`` attributes.

    Returns:
        A ``PreprocessorOptions`` instance built from the resolved values.
    """
    # Imported lazily: at module level this name is only imported under
    # TYPE_CHECKING.
    from openviking.session.extraction_preprocessor import PreprocessorOptions

    # (option keyword, config attribute, coercion, default when absent)
    option_specs = (
        ("max_span_tokens", "wm_v2_preprocess_max_span_tokens", int, 1200),
        ("min_span_tokens", "wm_v2_preprocess_min_span_tokens", int, 200),
        ("max_span_chars", "wm_v2_preprocess_max_span_chars", int, 1600),
        ("fallback_if_compact_ratio_above", "wm_v2_preprocess_fallback_ratio", float, 0.9),
        ("expand_budget_on_risk", "wm_v2_preprocess_expand_budget_on_risk", bool, True),
        ("max_facts_total", "wm_v2_preprocess_max_facts_total", int, 24),
        ("min_full_tokens_for_compact", "wm_v2_preprocess_min_full_tokens", int, 600),
        (
            "min_absolute_savings_tokens",
            "wm_v2_preprocess_min_absolute_savings_tokens",
            int,
            500,
        ),
        (
            "mmr_similarity_threshold",
            "wm_v2_preprocess_mmr_similarity_threshold",
            float,
            0.72,
        ),
        ("max_tool_output_chars", "wm_v2_preprocess_max_tool_output_chars", int, 300),
        ("max_tool_spans", "wm_v2_preprocess_max_tool_spans", int, 3),
    )

    resolved: Dict[str, Any] = {}
    for keyword, attribute, coerce, default in option_specs:
        resolved[keyword] = coerce(getattr(memory_config, attribute, default))
    return PreprocessorOptions(**resolved)

@staticmethod
def _record_wm_preprocessor_telemetry(
    phase: str,
    enabled: bool,
    packet: Optional[Any] = None,
    *,
    exception: bool = False,
) -> None:
    """Write WM preprocessing fields to the current telemetry object.

    The phase and the enabled flag are always recorded. When preprocessing
    is enabled and ``exception`` is set, only a fallback reason of
    ``"exception"`` is added. Otherwise, if a packet is supplied, its token
    estimates, span/fact counts, risk flags and fallback reason are recorded.

    Args:
        phase: Label of the WM prompt being prepared.
        enabled: Whether preprocessing is switched on.
        packet: Compact packet produced by the preprocessor, if any.
        exception: True when preprocessing raised and the caller fell back.
    """
    sink = get_current_telemetry()
    sink.set("wm.preprocess.phase", phase)
    sink.set("wm.preprocess.enabled", enabled)

    if enabled and exception:
        sink.set("wm.preprocess.fallback_reason", "exception")
    # Nothing packet-specific to report in any of these cases.
    if not enabled or exception or packet is None:
        return

    estimates = packet.token_estimates
    sink.set(
        "wm.preprocess.full_messages_tokens_est",
        estimates.full_messages_tokens_est,
    )
    sink.set(
        "wm.preprocess.compact_packet_tokens_est",
        estimates.compact_packet_tokens_est,
    )
    sink.set("wm.preprocess.saved_tokens_est", estimates.saved_tokens_est)
    sink.set("wm.preprocess.selected_spans_count", len(packet.selected_spans))
    sink.set("wm.preprocess.structured_facts_count", len(packet.structured_facts))
    sink.set("wm.preprocess.risk_flags", list(packet.risk_flags))
    # A missing/empty reason is normalised to the empty string.
    sink.set("wm.preprocess.fallback_reason", packet.fallback_reason or "")

@classmethod
def _log_wm_preprocessor_result(cls, phase: str, packet: Any) -> None:
    """Emit one INFO line summarising the preprocessor outcome for ``phase``.

    A packet flagged ``should_fallback`` is logged as ``FALLBACK`` together
    with its reason; any other packet is logged as ``ACTIVE`` together with
    the saved-token percentage relative to the full-message estimate.

    Args:
        phase: Label of the WM prompt being prepared; upper-cased in the log.
        packet: Compact packet returned by the preprocessor.
    """
    label = phase.upper()
    fell_back = packet.should_fallback
    reason = packet.fallback_reason if fell_back else None

    estimates = packet.token_estimates
    full_est = estimates.full_messages_tokens_est
    compact_est = estimates.compact_packet_tokens_est
    saved_est = estimates.saved_tokens_est
    span_count = len(packet.selected_spans)
    fact_count = len(packet.structured_facts)
    risk = packet.risk_flags

    if not fell_back:
        # Clamp the denominator to 1 so a zero full estimate cannot divide by zero.
        saved_pct = int(saved_est * 100 / max(full_est, 1))
        logger.info(
            "WM_PREPROCESS %s ACTIVE: full=%d compact=%d saved=%d (%d%%) spans=%d facts=%d risk=%s",
            label,
            full_est,
            compact_est,
            saved_est,
            saved_pct,
            span_count,
            fact_count,
            risk,
        )
        return

    logger.info(
        "WM_PREPROCESS %s FALLBACK: reason=%s full=%d compact=%d "
        "saved=%d spans=%d facts=%d risk=%s",
        label,
        reason,
        full_est,
        compact_est,
        saved_est,
        span_count,
        fact_count,
        risk,
    )

@classmethod
def _run_wm_preprocessor(
    cls,
    *,
    messages: List[Message],
    latest_archive_overview: str,
    formatted_messages: str,
    memory_config: Any,
    phase: str,
) -> Tuple[str, Optional[Any]]:
    """Pick the ``messages`` text for a WM v2 prompt, compacting when enabled.

    Preprocessing is best-effort: if it is disabled, if the packet asks for
    a fallback, or if anything raises, the caller gets ``formatted_messages``
    back unchanged.

    Args:
        messages: Archived messages handed to the preprocessor.
        latest_archive_overview: Previous WM text passed as ``latest_overview``.
        formatted_messages: Full-message rendering used as the fallback.
        memory_config: Object exposing ``wm_v2_preprocess_*`` attributes;
            a missing ``wm_v2_preprocess_enabled`` is treated as disabled.
        phase: Label recorded in telemetry and logs.

    Returns:
        ``(text, packet)`` where ``text`` is either the packet's
        ``wm_update_view`` or ``formatted_messages``, and ``packet`` is the
        compact packet, or ``None`` when preprocessing was disabled or failed.
    """
    enabled = bool(getattr(memory_config, "wm_v2_preprocess_enabled", False))
    cls._record_wm_preprocessor_telemetry(phase, enabled)
    if not enabled:
        return formatted_messages, None

    try:
        # Imported lazily so the preprocessor module is only loaded when enabled.
        from openviking.session.extraction_preprocessor import build_wm_compact_packet

        packet = build_wm_compact_packet(
            messages,
            latest_overview=latest_archive_overview,
            options=cls._build_wm_preprocessor_options(memory_config),
        )
        cls._record_wm_preprocessor_telemetry(phase, enabled, packet)
        cls._log_wm_preprocessor_result(phase, packet)
        if packet.should_fallback:
            return formatted_messages, packet
        return packet.wm_update_view, packet
    except Exception as e:
        # Deliberately broad: a preprocessor failure must never block WM
        # generation. Keep the traceback in the log so the failure can still
        # be diagnosed even though it is swallowed here.
        cls._record_wm_preprocessor_telemetry(phase, enabled, exception=True)
        logger.warning(
            "WM %s preprocessing failed (%s); using full messages",
            phase,
            e,
            exc_info=True,
        )
        return formatted_messages, None

async def _generate_archive_summary_async(
self,
messages: List[Message],
Expand All @@ -1948,7 +2075,8 @@ async def _generate_archive_summary_async(

formatted = "\n".join(self._format_message_for_wm(m) for m in messages)

vlm = get_openviking_config().vlm
config = get_openviking_config()
vlm = config.vlm
if not (vlm and vlm.is_available()):
turn_count = len([m for m in messages if m.role == "user"])
return (
Expand Down Expand Up @@ -1976,10 +2104,22 @@ async def _generate_archive_summary_async(
f"{len(latest_archive_overview or '')}B)"
)
try:
memory_config = getattr(config, "memory", None)
creation_messages, _packet = self._run_wm_preprocessor(
messages=messages,
latest_archive_overview=latest_archive_overview or "",
formatted_messages=formatted,
memory_config=memory_config,
phase="creation",
)

# The WM creation/update prompts must tolerate two `messages`
# formats: the raw `_format_message_for_wm()` fallback shape
# and the compact packet shape emitted by the preprocessor.
prompt = render_prompt(
"compression.ov_wm_v2",
{
"messages": formatted,
"messages": creation_messages,
"latest_archive_overview": latest_archive_overview or "",
},
)
Expand All @@ -1996,13 +2136,21 @@ async def _generate_archive_summary_async(
# -------- Branch 2: has prior WM v2 -> tool_call incremental update --------
_wm_debug(f"branch=UPDATE (prior WM={len(latest_archive_overview)}B)")
try:
memory_config = getattr(config, "memory", None)
update_messages, _packet = self._run_wm_preprocessor(
messages=messages,
latest_archive_overview=latest_archive_overview,
formatted_messages=formatted,
memory_config=memory_config,
phase="update",
)
reminders = Session._build_wm_section_reminders(latest_archive_overview)
if reminders:
_wm_debug(f"section_reminders injected ({len(reminders)}B)")
update_prompt = render_prompt(
"compression.ov_wm_v2_update",
{
"messages": formatted,
"messages": update_messages,
"latest_archive_overview": latest_archive_overview,
"wm_section_reminders": reminders,
},
Expand Down
15 changes: 15 additions & 0 deletions openviking/session/wm_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0
"""Shared Working Memory v2 constants."""

from typing import List

# Names of the seven Working Memory v2 sections.
# Defined in this standalone module and imported by openviking/session/session.py
# (which previously declared the same list inline).
# NOTE(review): presumably the list order is the section order of a WM
# document — confirm against the WM prompts before reordering.
WM_SEVEN_SECTIONS: List[str] = [
"Session Title",
"Current State",
"Task & Goals",
"Key Facts & Decisions",
"Files & Context",
"Errors & Corrections",
"Open Issues",
]
5 changes: 5 additions & 0 deletions openviking/utils/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def _estimate_embedding_input_tokens(text: str) -> int:
return max(1, cjk_chars + math.ceil(other_chars / 4))


def estimate_embedding_input_tokens(text: str) -> int:
    """Return the repository's CJK-aware token estimate for ``text``.

    Public entry point that simply delegates to the module-private
    ``_estimate_embedding_input_tokens`` so callers outside this module do
    not need to import an underscore-prefixed name.
    """
    token_estimate = _estimate_embedding_input_tokens(text)
    return token_estimate


def _truncate_embedding_input(
text: str,
max_tokens: int,
Expand Down
81 changes: 81 additions & 0 deletions openviking_cli/utils/config/memory_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,87 @@ class MemoryConfig(BaseModel):
"is ignored and the login user from the request context is used instead."
),
)
# Master switch for WM v2 compact preprocessing. The session code consults it
# for both the WM creation prompt and the incremental update prompt, so the
# description covers both.
wm_v2_preprocess_enabled: bool = Field(
    default=False,
    description=(
        "Enable compact pre-processing of archived messages for Working Memory v2 "
        "creation and incremental update prompts. "
        "When disabled, both prompts use the original full archived messages."
    ),
)
# Tuning knobs for WM v2 compact preprocessing. Each value is read by
# `_build_wm_preprocessor_options` in openviking/session/session.py and passed
# to `PreprocessorOptions` under the keyword named in the comment above it.
# The defaults here match the getattr fallbacks used there; keep them in sync.

# -> PreprocessorOptions.max_span_tokens
wm_v2_preprocess_max_span_tokens: int = Field(
default=1200,
ge=100,
description="Maximum estimated tokens to spend on selected evidence spans.",
)
# -> PreprocessorOptions.min_span_tokens
# NOTE(review): nothing visible here enforces min_span_tokens <= max_span_tokens.
wm_v2_preprocess_min_span_tokens: int = Field(
default=200,
ge=0,
description="Minimum span budget floor after adaptive preprocessing adjustments.",
)
# -> PreprocessorOptions.max_span_chars
wm_v2_preprocess_max_span_chars: int = Field(
default=1600,
ge=100,
description="Maximum characters allowed in each selected evidence span.",
)
# -> PreprocessorOptions.fallback_if_compact_ratio_above
# NOTE(review): values above 1.0 are accepted (le=10.0); presumably that lets a
# compact packet larger than the full messages pass the ratio check — confirm.
wm_v2_preprocess_fallback_ratio: float = Field(
default=0.9,
ge=0.1,
le=10.0,
description=(
"Fallback to full messages when compact packet tokens are greater than this "
"ratio of full message tokens."
),
)
# -> PreprocessorOptions.min_full_tokens_for_compact
wm_v2_preprocess_min_full_tokens: int = Field(
default=600,
ge=0,
description=(
"Skip compact preprocessing when the estimated full message tokens are "
"below this threshold. Set to 0 to force compact even for short sessions."
),
)
# -> PreprocessorOptions.min_absolute_savings_tokens
wm_v2_preprocess_min_absolute_savings_tokens: int = Field(
default=500,
ge=0,
description=(
"Fallback to full messages when compact preprocessing saves fewer than "
"this many estimated tokens, even if the ratio threshold passes."
),
)
# -> PreprocessorOptions.mmr_similarity_threshold
wm_v2_preprocess_mmr_similarity_threshold: float = Field(
default=0.72,
ge=0.0,
le=1.0,
description=(
"Maximum Jaccard similarity allowed between selected non-tool evidence "
"spans before they are considered redundant."
),
)
# -> PreprocessorOptions.max_tool_spans
wm_v2_preprocess_max_tool_spans: int = Field(
default=3,
ge=0,
description=(
"Maximum number of tool-heavy spans that can bypass normal MMR "
"deduplication in a compact packet."
),
)
# -> PreprocessorOptions.expand_budget_on_risk
wm_v2_preprocess_expand_budget_on_risk: bool = Field(
default=True,
description=(
"When enabled, risk flags can expand the evidence span budget before "
"compaction fallback is decided."
),
)
# -> PreprocessorOptions.max_facts_total
wm_v2_preprocess_max_facts_total: int = Field(
default=24,
ge=0,
description="Maximum structured facts retained in a compact packet.",
)
# -> PreprocessorOptions.max_tool_output_chars
wm_v2_preprocess_max_tool_output_chars: int = Field(
default=300,
ge=0,
description="Maximum characters preserved from each tool output in normalized spans.",
)
link_enabled: bool = Field(
default=False,
description=(
Expand Down
Loading
Loading