-
Notifications
You must be signed in to change notification settings - Fork 200
Expand file tree
/
Copy pathsession_memory_compact.py
More file actions
563 lines (447 loc) · 17.7 KB
/
session_memory_compact.py
File metadata and controls
563 lines (447 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
"""Session-memory-based compaction — uses an on-disk session memory summary
instead of an API call to compact the conversation.
Mirrors the npm ``src/services/compact/sessionMemoryCompact.ts`` module.
When session memory is available and up-to-date, this avoids the cost
of an API round-trip by reusing the background-maintained summary as
the compaction text. Falls back to ``None`` so the caller can use
the legacy API-based compact instead.
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from .agent_context_usage import estimate_tokens
from .agent_session import AgentMessage
if TYPE_CHECKING:
from .compact import CompactionResult
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class SessionMemoryCompactConfig:
    """Threshold knobs governing session-memory-based compaction."""

    # Preservation floor: keep at least this many tokens of recent messages.
    min_tokens: int = 10_000
    # Keep at least this many messages that carry real text content.
    min_text_block_messages: int = 5
    # Hard ceiling — never preserve more than this many tokens.
    max_tokens: int = 40_000
    # Per-section budget when truncating the session memory summary.
    max_section_tokens: int = 2_000
    # Overall budget for the session memory summary.
    max_total_tokens: int = 12_000


# Shared defaults used whenever a caller passes ``config=None``.
DEFAULT_CONFIG = SessionMemoryCompactConfig()
# ---------------------------------------------------------------------------
# Session memory file management
# ---------------------------------------------------------------------------
# Canonical section headers written into a fresh session-memory file.
# ``is_template_only`` treats these headers (plus italic placeholder
# lines) as scaffolding when deciding whether real content exists.
SESSION_MEMORY_TEMPLATE_SECTIONS = (
    '## User Profile',
    '## Project Context',
    '## Key Decisions & Rationale',
    '## Current Task Context',
    '## Important Patterns & Preferences',
    '## Learned Corrections',
    '## Tool Usage Patterns',
    '## Conversation Flow',
    '## Open Questions & Uncertainties',
)
"""Section headers used in the session memory template."""
def get_session_memory_dir() -> Path:
    """Directory holding session memory files: ``~/.claude/session-memory``."""
    return Path.home() / '.claude' / 'session-memory'
def get_session_memory_path() -> Path:
    """Full path of the single session-memory markdown file."""
    return get_session_memory_dir().joinpath('session.md')
def load_session_memory() -> str | None:
    """Read session memory from disk.

    Returns the stripped file contents, or ``None`` when the file is
    missing, empty, or unreadable — all of which mean "no usable
    memory" and send the caller down the API-based compaction path.
    """
    try:
        text = get_session_memory_path().read_text(encoding='utf-8').strip()
    except (OSError, UnicodeDecodeError):
        # Covers a missing file (FileNotFoundError is an OSError),
        # permission problems, and undecodable bytes alike.
        return None
    return text or None
def save_session_memory(content: str) -> None:
    """Write *content* to the session memory file, creating parent dirs."""
    target = get_session_memory_path()
    # The ~/.claude/session-memory directory may not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content, encoding='utf-8')
def is_template_only(content: str) -> bool:
    """Return True when *content* is nothing but template scaffolding.

    Scaffolding lines are: blank lines, ``##`` section headers, and
    italicised placeholder descriptions (``*...*`` or ``_..._``).
    The first line that is none of these counts as real content.
    """
    def _is_scaffold(line: str) -> bool:
        # Blank lines and section headers are part of the template.
        if not line or line.startswith('##'):
            return True
        # Italic placeholder descriptions left by the template.
        if line.startswith('*') and line.endswith('*'):
            return True
        return line.startswith('_') and line.endswith('_')

    return all(_is_scaffold(raw.strip()) for raw in content.strip().splitlines())
# ---------------------------------------------------------------------------
# Message analysis helpers
# ---------------------------------------------------------------------------
def _has_text_content(msg: AgentMessage) -> bool:
"""Check if a message has meaningful text content."""
content = msg.content.strip()
if not content:
return False
if content == '[Old tool result content cleared]':
return False
return True
def _get_tool_result_ids(msg: AgentMessage) -> list[str]:
"""Extract tool_call_ids from a tool-result message."""
if msg.role == 'tool' and msg.tool_call_id:
return [msg.tool_call_id]
return []
def _has_tool_use_with_ids(msg: AgentMessage, ids: set[str]) -> bool:
"""Check if an assistant message contains tool_use blocks matching any of the given IDs."""
if msg.role != 'assistant':
return False
tool_calls = msg.metadata.get('tool_calls') or msg.tool_calls
if not tool_calls:
return False
for tc in tool_calls:
tc_id = tc.get('id', '') if isinstance(tc, dict) else ''
if tc_id in ids:
return True
return False
# ---------------------------------------------------------------------------
# Index calculation
# ---------------------------------------------------------------------------
def adjust_index_to_preserve_api_invariants(
    messages: list[AgentMessage],
    keep_from: int,
) -> int:
    """Walk backwards so the kept message suffix stays API-valid.

    Two invariants are enforced on ``messages[keep_from:]``:

    1. Every kept tool_result message has its matching assistant
       tool_use message kept as well.
    2. Assistant messages sharing the same ``message_id`` (thinking
       blocks split across entries) are kept together.

    Parameters
    ----------
    messages:
        Full conversation history.
    keep_from:
        Candidate index of the first message to preserve.

    Returns
    -------
    int
        Possibly smaller index whose suffix satisfies both invariants.
    """
    if keep_from <= 0:
        return 0
    # Step 1: Ensure tool_use/tool_result pairs.
    kept_tool_result_ids: set[str] = set()
    for msg in messages[keep_from:]:
        kept_tool_result_ids.update(_get_tool_result_ids(msg))
    if kept_tool_result_ids:
        idx = keep_from - 1
        while idx >= 0:
            msg = messages[idx]
            if _has_tool_use_with_ids(msg, kept_tool_result_ids):
                # BUG FIX: record the old boundary BEFORE moving it.
                # The previous code assigned ``keep_from = idx`` first,
                # which made ``messages[idx:keep_from]`` an empty slice,
                # so tool_results newly pulled into the kept range never
                # had their ids registered — their own tool_use messages
                # further back could then be dropped.
                prev_keep_from = keep_from
                keep_from = idx
                # Register tool_results among the newly-included messages
                # so their tool_uses are also retained on later iterations.
                for m in messages[idx:prev_keep_from]:
                    kept_tool_result_ids.update(_get_tool_result_ids(m))
            idx -= 1
    # Step 2: Ensure thinking block continuity (same message_id).
    kept_msg_ids: set[str] = set()
    for msg in messages[keep_from:]:
        if msg.role == 'assistant' and msg.message_id:
            kept_msg_ids.add(msg.message_id)
    if kept_msg_ids:
        idx = keep_from - 1
        while idx >= 0:
            msg = messages[idx]
            if msg.role == 'assistant' and msg.message_id in kept_msg_ids:
                keep_from = idx
            idx -= 1
    return keep_from
def calculate_messages_to_keep_index(
    messages: list[AgentMessage],
    last_summarized_index: int,
    model: str = '',
    config: SessionMemoryCompactConfig | None = None,
) -> int:
    """Calculate the index from which to preserve messages.

    Starts from ``last_summarized_index + 1`` and expands backwards
    to meet the configured minimums (token count, text block count).
    Stops if the hard max_tokens cap is reached.

    Parameters
    ----------
    messages:
        Full conversation history.
    last_summarized_index:
        Index of the last message already covered by the session
        memory summary; everything after it is always kept.
    model:
        Model name forwarded to ``estimate_tokens``.
    config:
        Thresholds; ``DEFAULT_CONFIG`` when None.

    Returns
    -------
    int
        Index of the first message to keep. ``len(messages)`` means
        nothing needs preserving.
    """
    if config is None:
        config = DEFAULT_CONFIG
    start_index = last_summarized_index + 1
    if start_index >= len(messages):
        # Everything is already summarized — nothing to preserve.
        return len(messages)
    keep_from = start_index
    token_count = 0
    text_block_count = 0
    # Count forward from keep_from to end: these messages are mandatory,
    # so their tokens/text blocks count toward the minimums for free.
    for msg in messages[keep_from:]:
        token_count += estimate_tokens(msg.content, model)
        if _has_text_content(msg):
            text_block_count += 1
    # Expand backwards if minimums not met
    idx = keep_from - 1
    while idx >= 0:
        # Check if we've reached a compact boundary — don't go past it
        # (messages before it were already summarized by a prior compact).
        if messages[idx].metadata.get('kind') == 'compact_boundary':
            break
        msg_tokens = estimate_tokens(messages[idx].content, model)
        # Hard cap: stop if adding this would exceed max_tokens
        if token_count + msg_tokens > config.max_tokens:
            break
        # Expand backwards
        keep_from = idx
        token_count += msg_tokens
        if _has_text_content(messages[idx]):
            text_block_count += 1
        # Check if minimums are met; only break AFTER including this
        # message so the minimums are satisfied by the kept range itself.
        if (token_count >= config.min_tokens
                and text_block_count >= config.min_text_block_messages):
            break
        idx -= 1
    # Ensure API invariants (tool_use/result pairing, thinking continuity)
    keep_from = adjust_index_to_preserve_api_invariants(messages, keep_from)
    return keep_from
# ---------------------------------------------------------------------------
# Session memory truncation
# ---------------------------------------------------------------------------
def truncate_session_memory(
    content: str,
    config: SessionMemoryCompactConfig | None = None,
) -> tuple[str, bool]:
    """Trim session-memory sections to the configured token budgets.

    Returns ``(truncated_content, was_truncated)``. When the whole text
    already fits ``max_total_tokens`` it is returned untouched.
    """
    cfg = DEFAULT_CONFIG if config is None else config
    if estimate_tokens(content, '') <= cfg.max_total_tokens:
        return content, False
    # Group lines into chunks, starting a new chunk at each '## ' header.
    chunks: list[list[str]] = []
    chunk: list[str] = []
    for line in content.splitlines(keepends=True):
        if chunk and line.strip().startswith('## '):
            chunks.append(chunk)
            chunk = [line]
        else:
            chunk.append(line)
    if chunk:
        chunks.append(chunk)
    # Clip each over-budget chunk, keeping leading lines until the
    # per-section limit would be exceeded.
    pieces: list[str] = []
    clipped = False
    for chunk in chunks:
        text = ''.join(chunk)
        if estimate_tokens(text, '') <= cfg.max_section_tokens:
            pieces.append(text)
            continue
        kept: list[str] = []
        used = 0
        for line in chunk:
            cost = estimate_tokens(line, '')
            if used + cost > cfg.max_section_tokens:
                kept.append('...(truncated)\n')
                clipped = True
                break
            kept.append(line)
            used += cost
        pieces.append(''.join(kept))
    result = ''.join(pieces)
    # Even after per-section clipping the total may still be over budget;
    # report truncation so the caller can point readers at the full file.
    if estimate_tokens(result, '') > cfg.max_total_tokens:
        clipped = True
    return result, clipped
# ---------------------------------------------------------------------------
# Core session memory compaction
# ---------------------------------------------------------------------------
def try_session_memory_compaction(
    messages: list[AgentMessage],
    model: str = '',
    last_summarized_message_id: str | None = None,
    auto_compact_threshold: int | None = None,
    config: SessionMemoryCompactConfig | None = None,
) -> 'CompactionResult | None':
    """Attempt session-memory-based compaction.

    Returns a :class:`CompactionResult` if session memory is available
    and the compaction succeeds, or ``None`` to signal the caller should
    fall back to API-based compaction.

    Parameters
    ----------
    messages:
        The current session messages.
    model:
        Model name for token estimation.
    last_summarized_message_id:
        The message_id of the last message that was included in the
        session memory summary. Messages after this are preserved.
    auto_compact_threshold:
        If provided, return None if post-compact tokens exceed this.
    config:
        Compaction configuration thresholds.
    """
    # Imported here (not at module top) to avoid a circular import with
    # the compact module, which mirrors this fallback relationship.
    from .compact import CompactionResult
    if config is None:
        config = DEFAULT_CONFIG
    # Check environment gates: any non-empty value disables this path.
    if os.environ.get('DISABLE_CLAUDE_CODE_SM_COMPACT'):
        return None
    # Load session memory; bail out when missing/empty or still the
    # untouched template — there is nothing useful to compact with.
    session_memory = load_session_memory()
    if session_memory is None:
        return None
    if is_template_only(session_memory):
        return None
    # Find the boundary message: the last message the background
    # summarizer has already folded into the session memory.
    last_summarized_index: int | None = None
    if last_summarized_message_id:
        for i, msg in enumerate(messages):
            if msg.message_id == last_summarized_message_id:
                last_summarized_index = i
                break
    if last_summarized_index is None:
        # No boundary found — can't determine what's already summarized
        # Fall back to legacy compact
        return None
    # Calculate which messages to preserve
    keep_from = calculate_messages_to_keep_index(
        messages, last_summarized_index, model=model, config=config,
    )
    messages_to_keep = list(messages[keep_from:])
    # Filter out old compact boundaries from kept messages
    messages_to_keep = [
        m for m in messages_to_keep
        if m.metadata.get('kind') != 'compact_boundary'
    ]
    # Truncate session memory if needed (per-section and total budgets)
    truncated_memory, was_truncated = truncate_session_memory(
        session_memory, config=config,
    )
    # Build the compaction result. pre_tokens counts message text only.
    pre_tokens = sum(estimate_tokens(m.content, model) for m in messages)
    # Boundary marker: a synthetic user message recording that (and how)
    # compaction happened; recognized later via metadata['kind'].
    boundary = AgentMessage(
        role='user',
        content=(
            '<system-reminder>\n'
            f'Earlier conversation was compacted using session memory. '
            f'{len(messages) - len(messages_to_keep)} messages summarized.\n'
            '</system-reminder>'
        ),
        message_id='compact_boundary',
        metadata={
            'kind': 'compact_boundary',
            'source': 'session_memory',
            'pre_compact_token_count': pre_tokens,
        },
    )
    summary_content = (
        'Here is a summary of our conversation so far:\n\n'
        f'{truncated_memory}'
    )
    if was_truncated:
        # Point readers at the untruncated file on disk.
        memory_path = get_session_memory_path()
        summary_content += (
            f'\n\n(Session memory was truncated. '
            f'Full version at: {memory_path})'
        )
    summary_msg = AgentMessage(
        role='user',
        content=summary_content,
        message_id='compact_summary',
        metadata={
            'kind': 'compact_summary',
            'is_compact_summary': True,
            'source': 'session_memory',
        },
    )
    post_messages = [boundary, summary_msg] + messages_to_keep
    post_tokens = sum(estimate_tokens(m.content, model) for m in post_messages)
    # Check threshold: if the result would still trip auto-compact,
    # this path gained nothing — let the API-based compact handle it.
    if auto_compact_threshold is not None and post_tokens > auto_compact_threshold:
        return None
    return CompactionResult(
        boundary_message=boundary,
        summary_messages=[summary_msg],
        messages_to_keep=messages_to_keep,
        pre_compact_token_count=pre_tokens,
        post_compact_token_count=post_tokens,
        true_post_compact_token_count=post_tokens,
        summary_text=truncated_memory,
    )
# ---------------------------------------------------------------------------
# Session memory extraction (lightweight background summary updater)
# ---------------------------------------------------------------------------
# Prompt for the richer LLM-based session-memory extraction. Its section
# headers mirror SESSION_MEMORY_TEMPLATE_SECTIONS. NOTE(review): no local
# reference in this module — presumably consumed by the background
# summarizer elsewhere; confirm before removing.
SESSION_MEMORY_EXTRACTION_PROMPT = """Analyze the conversation so far and update the session memory.
Extract key information into these sections:
## User Profile
Who the user is, their role, expertise level, and preferences.
## Project Context
What project/codebase is being worked on, its structure, and tech stack.
## Key Decisions & Rationale
Important decisions made during this session and why.
## Current Task Context
What the user is currently working on and the state of that work.
## Important Patterns & Preferences
Coding style, conventions, or preferences observed.
## Learned Corrections
Mistakes made and corrections applied — things to avoid repeating.
## Tool Usage Patterns
Which tools work well, preferred approaches for common tasks.
## Conversation Flow
Major topic transitions and how the conversation has progressed.
## Open Questions & Uncertainties
Unresolved questions or areas of ambiguity.
Write concisely. Focus on information that would be useful for continuing
this conversation after a context reset. Omit sections with no content."""
def extract_session_memory_from_messages(
    messages: list[AgentMessage],
    model: str = '',
) -> str:
    """Build a session memory summary from conversation messages.

    This is a lightweight local extraction — it walks the messages and
    builds a structured summary without an API call. For richer
    summaries, the full LLM-based extraction should be used.

    Parameters
    ----------
    messages:
        Conversation messages to summarize.
    model:
        Accepted for signature parity with the LLM-based path; unused
        by this local extractor.

    Returns
    -------
    str
        Markdown summary with Task Context, Tool Usage, Project Context
        and Conversation Flow sections (empty sections are omitted).
    """
    user_messages: list[str] = []
    tool_names_used: set[str] = set()
    file_paths: set[str] = set()
    for msg in messages:
        if msg.role == 'user' and _has_text_content(msg):
            user_messages.append(msg.content[:200])
        elif msg.role == 'tool' and msg.name:
            tool_names_used.add(msg.name)
            path = msg.metadata.get('path')
            if isinstance(path, str):
                file_paths.add(path)
    sections: list[str] = []
    if user_messages:
        sections.append('## Current Task Context')
        # The five most recent user messages approximate the active task.
        for um in user_messages[-5:]:
            sections.append(f'- {um[:100]}')
    if tool_names_used:
        sections.append('\n## Tool Usage Patterns')
        sections.append(f'Tools used: {", ".join(sorted(tool_names_used))}')
    if file_paths:
        sections.append('\n## Project Context')
        # BUG FIX: sort BEFORE slicing. The old code sliced an unordered
        # set cast (list(file_paths)[:20]) and then sorted, which picked
        # an arbitrary, non-deterministic 20 paths on each call.
        sections.append(f'Files accessed: {", ".join(sorted(file_paths)[:20])}')
    sections.append('\n## Conversation Flow')
    sections.append(f'Total messages: {len(messages)}')
    sections.append(
        f'User messages: {sum(1 for m in messages if m.role == "user")}'
    )
    return '\n'.join(sections)