feat: true O(1) stateful loop engine using MambaCache recurrence #1
ItsMick wants to merge 1 commit into batteryphil:main from
Conversation
Replace the re-tokenize-per-loop approach with single-token recurrent steps via HuggingFace MambaCache. The existing latent loop rebuilds SSM state from scratch every iteration (O(n) per loop, sequence grows). The new engine prefills once, then feeds single spacer tokens while passing cache state forward — O(1) per iteration, constant memory.

Key changes:
- stateful_engine.py: StatefulLoopEngine with cache-based iteration
- session_memory.py: latent_turn() upgraded to O(1) cache steps
- Benchmark: 2.35x speedup on 2.8B, 3.17x on 130M (CPU, no CUDA kernels)
- Correctness: prefill hidden states match exactly (cosine sim = 1.0)
- API finding: Mamba uses cache_params + cache_position, not past_key_values

mamba_engine.py is unchanged — original training engine preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Pull request overview
This PR introduces a new “stateful” latent-loop execution path for Mamba models that uses recurrent single-token steps with MambaCache (instead of re-tokenizing an expanding prompt each loop) to achieve O(1) per-iteration cost, and updates session memory to use the same approach. It also adds scripts and docs to validate correctness and benchmark LLPS.
Changes:
- Add `StatefulLoopEngine` implementing cache-based single-token loop steps and cache-backed generation.
- Update `session_memory.latent_turn()` to use prefill + recurrent cache steps (no expanding prompt).
- Add validation/benchmark scripts and documentation capturing cache API findings and performance results.
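The asymptotic claim behind this change can be sketched with simple token-count arithmetic. The helper names below are hypothetical, not code from this PR: a stateless loop re-tokenizes the prompt plus `lp` spacers each iteration, while the cache-based path processes each spacer exactly once after a single prefill.

```python
def stateless_tokens(seq_len: int, loops: int) -> int:
    # Each iteration reprocesses the prompt plus lp spacers: O(n) per loop.
    return sum(seq_len + lp for lp in range(loops))

def stateful_tokens(seq_len: int, loops: int) -> int:
    # One prefill of the prompt, then one spacer token per loop: O(1) per loop.
    return seq_len + loops

# With a 512-token prompt and 32 latent loops:
print(stateless_tokens(512, 32))  # 16880 tokens processed
print(stateful_tokens(512, 32))   # 544 tokens processed
```

The gap widens quadratically with loop count, which is why the benchmark speedups grow with model depth and iteration budget.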
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| stateful_engine.py | New stateful loop engine built around MambaCache recurrent stepping + generation. |
| session_memory.py | Switches latent turn loop to cache-based O(1) stepping and cache-backed generation. |
| validate_stateful.py | New correctness comparison script (hidden-state trace + generation smoke test). |
| benchmark_llps.py | New LLPS benchmark comparing stateless re-tokenize vs stateful cache stepping. |
| docs/cache_api_findings.md | Notes on the discovered Mamba cache API differences and required cache_position. |
| docs/correctness_validation.md | Captures validation results and observed divergence after loop 0. |
| docs/llps_benchmark.md | Captures benchmark results and analysis. |
| docs/blockers.md | Tracks environment blockers and "kill switch" outcomes. |
| CONTRIBUTION.md | Narrative description of the architectural change, results, and remaining work. |
```python
seq_len = toks["input_ids"].shape[1]
out = mdl(
    **toks,
    cache_params=cache,
```
In the prefill call you pass an existing cache_params=cache but do not provide cache_position. Per the documented Mamba API, providing cache_params without cache_position raises a ValueError, so latent_turn() will fail at runtime. Provide a prefill-mode cache_position (shape == conv kernel size) or avoid passing cache_params during prefill and only start passing it on the single-token decode steps.
Suggested change:

```diff
 seq_len = toks["input_ids"].shape[1]
+conv_kernel = mdl.config.conv_kernel
+prefill_start = max(seq_len - conv_kernel, 0)
+prefill_cache_pos = torch.arange(
+    prefill_start,
+    prefill_start + conv_kernel,
+    device="cuda"
+)
 out = mdl(
     **toks,
     cache_params=cache,
+    cache_position=prefill_cache_pos,
```
```diff
 for lp in range(MAX_LOOPS):
     text = prompt + "=" * lp
     toks = tok(text, return_tensors="pt",
                truncation=True, max_length=512).to("cuda")
     # Pass cache_params so the SSM state is updated in-place
     out = mdl(
         **toks,
         cache_params=cache,
         use_cache=True,
         output_hidden_states=True
     )
     h = out.hidden_states[-1][0, -1, :].float()
     ln = torch.tensor([lp / m], dtype=torch.float32, device="cuda")
     p = head(torch.cat([h, ln]).unsqueeze(0)).item()
     if p >= HALT_THRESH:
         break

 # Autoregressive surface generation from updated cache state
 cache_pos = torch.tensor([seq_len + lp], device="cuda")
 step_out = mdl(
     input_ids=spacer,
     cache_params=cache,
     cache_position=cache_pos,
     use_cache=True,
     output_hidden_states=True
 )
 h = step_out.hidden_states[-1][0, -1, :].float()

 # Autoregressive surface generation from accumulated cache state
 gen_cache_pos = torch.tensor([seq_len + lp + 1], device="cuda")
 gen_out = mdl.generate(
-    toks["input_ids"],
+    spacer,
     cache_params=cache,
     cache_position=gen_cache_pos,
     max_new_tokens=120,
     do_sample=False,
     repetition_penalty=1.1,
     use_cache=True
 )

 surface = tok.decode(
-    gen_out[0][toks["input_ids"].shape[1]:],
+    gen_out[0][1:],
     skip_special_tokens=True
 ).strip()
 return surface, lp + 1, round(p, 3)
```
Generation now always feeds an extra spacer token via mdl.generate(spacer, ...), which means the generation context includes one more '=' than the number of recurrent steps actually applied (and differs from the old behavior when halting at loop 0, which generated from the prompt with zero '='). This off-by-one changes outputs and makes lp + 1 ambiguous. Align the number of spacer steps applied to the cache with the prompt variant you intend to generate from, and adjust gen_cache_pos / lp accounting accordingly.
```python
self.tok.pad_token = self.tok.eos_token

self.model = AutoModelForCausalLM.from_pretrained(
    engine_dir, dtype=torch.bfloat16,
```
AutoModelForCausalLM.from_pretrained() is called with dtype=..., but the rest of the codebase uses the HuggingFace torch_dtype= kwarg for this API. If dtype is not accepted by the model class, this will raise at load time. Consider switching to torch_dtype= (and selecting torch.float32 on CPU) for consistency with session_memory.py/agent_loop.py and other scripts.
Suggested change:

```diff
-    engine_dir, dtype=torch.bfloat16,
+    engine_dir,
+    torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
```
```python
for lp in range(max_loops):
    t0 = time.perf_counter()

    # Halting check
    if self._has_head:
        ln = torch.tensor([lp / max_loops],
                          dtype=torch.float32, device=self.device)
        p_halt = self.head(torch.cat([h, ln]).unsqueeze(0)).item()

        if verbose:
            print(f"  Loop {lp}: P(halt)={p_halt:.3f}")

        if p_halt >= halt_threshold:
            loop_latencies.append((time.perf_counter() - t0) * 1000)
            break
    elif verbose:
        print(f"  Loop {lp}: (no halting head)")

    # Single-token recurrent step — O(1), no sequence growth
    cache_pos = torch.tensor([seq_len + lp], device=self.device)
    step_out = self.model(
        input_ids=spacer,
        cache_params=cache,
        cache_position=cache_pos,
        use_cache=True,
        output_hidden_states=True
    )
    # cache is mutated in-place; step_out.cache_params is the same object
    h = step_out.hidden_states[-1][0, -1, :].float()

    loop_latencies.append((time.perf_counter() - t0) * 1000)

# --- Generate answer from final state ---
# Pass the accumulated cache to generate. The cache already holds
# the full context (prompt + all spacer iterations).
try:
    gen_cache_pos = torch.tensor(
        [seq_len + lp + 1], device=self.device
    )
    out_ids = self.model.generate(
        input_ids=spacer,
        cache_params=cache,
        cache_position=gen_cache_pos,
```
The loop can break on the halting check before any recurrent spacer step is applied for that iteration, but generate() is then invoked with input_ids=spacer and cache_position=seq_len + lp + 1, which effectively applies an extra '=' even when halting at lp=0. This changes semantics vs the original engine (which could generate with zero '=' when halting immediately) and makes loop counts hard to interpret. Decide whether the halting check is evaluated before or after applying the spacer step, and adjust the step ordering / gen_cache_pos / returned lp so the cache contains exactly the intended number of spacers at generation time.
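One way to make the spacer accounting unambiguous is to apply the spacer step first and evaluate the halting head on the state that step produces, so the cache holds exactly `loops_executed` spacers when `generate()` runs. A minimal torch-free sketch of that bookkeeping, where `should_halt` is a stand-in for the PR's halting head and the function names are illustrative:

```python
def run_latent_loop(seq_len, max_loops, should_halt):
    """Track how many spacers the cache holds under step-then-check ordering."""
    spacers_in_cache = 0
    for lp in range(max_loops):
        # Apply the recurrent spacer step first
        # (i.e. model(spacer, cache_position=[seq_len + lp]))...
        spacers_in_cache += 1
        # ...then evaluate halting on the hidden state that step produced.
        if should_halt(lp):
            break
    # generate()'s first input token lands at the next cache position.
    gen_cache_pos = seq_len + spacers_in_cache
    return spacers_in_cache, gen_cache_pos

# Halt immediately: exactly one spacer applied, generation follows it directly.
print(run_latent_loop(100, 8, lambda lp: lp == 0))  # (1, 101)
# Never halt: max_loops spacers applied.
print(run_latent_loop(100, 8, lambda lp: False))    # (8, 108)
```

With this invariant, `gen_cache_pos` is always `seq_len + spacers_in_cache`, and the returned loop count equals the number of spacers actually in the cache.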
```python
out = model(input_ids=input_ids, use_cache=True, output_hidden_states=True)
cache = out.cache_params
spacer = torch.tensor([[spacer_id]], device=device)
for lp in range(max_loops):
    cache_pos = torch.tensor([seq_len + lp], device=device)
    model(input_ids=spacer, cache_params=cache,
          cache_position=cache_pos, use_cache=True)

gen_pos = torch.tensor([seq_len + max_loops], device=device)
stat_gen = model.generate(spacer, cache_params=cache,
                          cache_position=gen_pos,
                          max_new_tokens=40, do_sample=False,
                          repetition_penalty=1.1, use_cache=True)
stat_text = tok.decode(stat_gen[0][1:], skip_special_tokens=True).strip()
```
The stateful generation comparison does for lp in range(max_loops): model(input_ids=spacer, ...) and then calls model.generate(spacer, cache_params=cache, ...), which adds another '=' token before generating. That makes the stateful path effectively use max_loops + 1 spacers, so the outputs and timing aren’t comparable to the original path’s prompt + '=' * max_loops. Adjust the number of manual spacer steps or the generate() input so both paths represent the same effective context length.
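The mismatch is easy to state as an effective-spacer count. A small illustrative check (hypothetical helper names, not the benchmark's code): since the stateless path's context is `prompt + "=" * max_loops`, the stateful path should either apply `max_loops - 1` manual steps and let `generate()` consume the final spacer, or apply `max_loops` steps and generate from a non-spacer token.

```python
def stateless_spacers(max_loops: int) -> int:
    # Context is prompt + "=" * max_loops.
    return max_loops

def stateful_spacers(manual_steps: int, generate_input_is_spacer: bool) -> int:
    # Manual recurrent steps plus the token fed to generate(), if it is a spacer.
    return manual_steps + (1 if generate_input_is_spacer else 0)

# As written in the benchmark: max_loops manual steps + a spacer into generate().
print(stateful_spacers(8, True))   # 9: one more than stateless_spacers(8)
# One fix: run max_loops - 1 manual steps; generate() consumes the last spacer.
print(stateful_spacers(7, True))   # 8: matches stateless_spacers(8)
```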
```python
model = AutoModelForCausalLM.from_pretrained(
    engine_dir,
    dtype=torch.bfloat16 if device == "cuda" else torch.float32,
```
AutoModelForCausalLM.from_pretrained() is called with dtype=.... Elsewhere in this repo, model loading consistently uses torch_dtype= for this API. If dtype isn’t supported, the benchmark script will fail at startup; switch to torch_dtype= for consistency and compatibility.
Suggested change:

```diff
-    dtype=torch.bfloat16 if device == "cuda" else torch.float32,
+    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
```
```markdown
## Environment

- **transformers version**: 5.3.0
- **MambaCache location**: `transformers.models.mamba.modeling_mamba` (NOT `transformers.cache_utils`)
- **Import**: `from transformers import MambaCache` (auto-imported from model module)
- **Python**: 3.14
- **GPU**: Not available at inspection time (CPU-only system)
```
This doc asserts a specific transformers version (5.3.0) and describes API behavior tied to that version, but requirements.txt only specifies transformers>=4.40.0. If this PR depends on behavior introduced in 5.x (e.g., cache_position requirements or import paths), the dependency should be pinned/updated or the doc should clarify the minimum supported version to avoid runtime breakage for users following requirements.txt.
```python
answer = self.tok.decode(
    out_ids[0][final_ids.shape[1]:],
    skip_special_tokens=True
)

return answer, lp, p_halt, loop_latencies
```
generate() returns lp directly, but lp is a 0-based loop index from for lp in range(max_loops). If the loop runs to completion without breaking, lp will be max_loops - 1 even though max_loops iterations were executed (and the fallback path also assumes lp + 1). Track an explicit loops_executed counter (or return lp + 1 consistently) so callers get an accurate loop count.
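A straightforward fix is an explicit counter incremented once per executed iteration, independent of where the loop breaks. A torch-free sketch of the counting (the names are illustrative):

```python
def count_loops(max_loops, halt_at=None):
    """Return the number of iterations actually executed."""
    loops_executed = 0
    for lp in range(max_loops):
        loops_executed += 1
        if halt_at is not None and lp >= halt_at:
            break
    return loops_executed

# Runs to completion: the 0-based index lp would end at max_loops - 1,
# but the counter correctly reports max_loops iterations.
print(count_loops(4))     # 4
# Halts on the first iteration: exactly one iteration executed.
print(count_loops(4, 0))  # 1
```

Returning `loops_executed` keeps the completed-loop case and the early-halt case consistent without callers having to know whether `break` fired.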
Summary
- Replaces the re-tokenize-per-loop latent loop with single-token recurrent steps via HuggingFace MambaCache
- Prefills once, then feeds single "=" spacer tokens while passing cache state forward — O(1) per iteration, constant memory, no sequence growth
- mamba_engine.py is unchanged — original training engine preserved

The Problem
This is functionally equivalent to pause tokens with a growing prompt. The SSM state is rebuilt from scratch each time — no recurrent state is carried forward.
The Fix
Each loop is a single-token recurrent step. Sequence length never grows. Memory usage is constant. This is the correct way to use an SSM recurrently.
Key API Finding
Mamba uses a different cache interface than Transformers:
- Transformer models pass `past_key_values=cache`; Mamba passes `cache_params=cache`
- Transformer models read `out.past_key_values`; Mamba reads `out.cache_params`
- For Mamba, `cache_position` is required

The `cache_position` tensor shape determines prefill vs decode mode. Full details in docs/cache_api_findings.md.

Benchmark Results
Mamba-2.8B (CPU, 64 layers)
Mamba-130M (CPU, 24 layers)
Files
- stateful_engine.py: StatefulLoopEngine class with O(1) cache iteration
- session_memory.py: latent_turn() upgraded to O(1) cache steps
- validate_stateful.py
- benchmark_llps.py
- CONTRIBUTION.md
- docs/cache_api_findings.md
- docs/correctness_validation.md
- docs/llps_benchmark.md
- docs/blockers.md

Test plan
🤖 Generated with Claude Code