Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing"). #13681

Draft
gueraf wants to merge 19 commits into huggingface:main from gueraf:wan-rolling-kv-cache

Conversation


@gueraf gueraf commented May 5, 2026

What does this PR do?

  • Implements a (rolling) KV cache for Wan models to enable autoregressive generation.
  • Tries to mirror the KV cache pattern in transformer_flux2.py as much as possible.
  • Videos and byte-level equivalence against upstream Self Forcing are tested in https://github.com/gueraf/self-forcing-diffusers/releases/tag/inline-rolling-kv-20260504.
  • This initial PR does not yet implement sink-frame pinning, and lacks some model-level adjustments (Self Forcing has cross-attention QK norms and per-frame timestep modulation).
  • Adds tests for cache append/overwrite/window behavior.

Motivation

This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.

As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models without having to ship many custom modules, ideally none.

Progresses #12600

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

gueraf added 4 commits May 4, 2026 11:22
Adds inline rolling KV cache classes to transformer_wan.py, enabling
efficient autoregressive chunk-wise video generation without recomputing
previous chunks' attention context.

New public API (importable from diffusers.models.transformers.transformer_wan):
- WanRollingKVBlockCache: per-block K/V storage with optional cross-attn cache
- WanRollingKVCache: container managing all blocks, write modes, and window trimming

Usage:
    cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=8000)
    transformer(..., attention_kwargs={"rolling_kv_cache": cache})

Key features:
- "append" mode: grows the temporal prefix each chunk
- "overwrite" mode: writes clean K/V at a specific absolute token offset,
  allowing arbitrary temporal placement (e.g. injecting ground-truth frames)
- window_size: trims oldest tokens to keep memory bounded
- cache_cross_attention: reuses text encoder projections across chunks
- frame_offset parameter on WanRotaryPosEmbed and WanTransformer3DModel.forward
  for correct temporal RoPE positioning during chunk-wise generation

Also adds examples/inference/autoregressive_video_generation.py demonstrating
chunk-wise generation with the Self-Forcing transformer.
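The append-and-trim semantics described above can be sketched with a toy, framework-free cache. Names and list-based storage here are illustrative only — the PR's actual `WanRollingKVCache` stores per-block K/V tensors:

```python
class ToyRollingKVCache:
    """Minimal sketch of a rolling KV cache: appends each chunk's
    key/value tokens and trims the oldest tokens past a fixed window."""

    def __init__(self, window_size):
        self.window_size = window_size
        self.keys = []    # one entry per cached token (stand-in for a tensor dim)
        self.values = []

    def append(self, chunk_keys, chunk_values):
        self.keys.extend(chunk_keys)
        self.values.extend(chunk_values)
        # Trim the oldest tokens so memory stays bounded by window_size.
        if len(self.keys) > self.window_size:
            excess = len(self.keys) - self.window_size
            self.keys = self.keys[excess:]
            self.values = self.values[excess:]


cache = ToyRollingKVCache(window_size=4)
cache.append([0, 1], [10, 11])   # chunk 0
cache.append([2, 3], [12, 13])   # chunk 1 -> window now full
cache.append([4, 5], [14, 15])   # chunk 2 -> oldest chunk evicted
print(cache.keys)                # [2, 3, 4, 5]
```

Each new chunk attends over the surviving prefix, so earlier chunks' attention context never needs recomputation while memory stays constant.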
Tests cover WanRollingKVBlockCache and WanRollingKVCache state, both helper
functions, append mode (incremental chunk assertions), window trimming,
overwrite mode, cross-attention caching, and frame offset behavior.
Split WanAttnProcessor.__call__ into _wan_self_attention and
_wan_cross_attention module-level functions; __call__ is now a
3-line router. Also hoists apply_rotary_emb to module level.
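The router split follows a common refactoring pattern: hoist the two attention paths to module level and keep `__call__` as a thin dispatcher. A minimal sketch with toy stand-in functions (not the actual Wan attention code):

```python
def _toy_self_attention(hidden_states):
    # Stand-in for _wan_self_attention: attends within hidden_states.
    return [h * 2 for h in hidden_states]


def _toy_cross_attention(hidden_states, encoder_hidden_states):
    # Stand-in for _wan_cross_attention: attends over encoder states.
    return [h + sum(encoder_hidden_states) for h in hidden_states]


class ToyAttnProcessor:
    def __call__(self, hidden_states, encoder_hidden_states=None):
        # The "3-line router": dispatch on whether cross-attn context exists.
        if encoder_hidden_states is None:
            return _toy_self_attention(hidden_states)
        return _toy_cross_attention(hidden_states, encoder_hidden_states)
```

Keeping the two paths as module-level functions makes them individually testable and lets the cached cross-attention path be swapped in without touching the self-attention code.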

Moves rolling KV cache tests into test_models_transformer_wan.py
(deleted test_wan_rolling_kv_cache.py). Tests now use arange-based
deterministic inputs and assert on explicit Python lists.
…rename _tok

- _get_kv_projections: new helper that projects only K and V; used in the
  cross-attention cached path instead of discarding Q from _get_qkv_projections
- _wan_self_attention / _wan_cross_attention: annotate backend and parallel_config
  as AttentionBackendName | None and ParallelConfig | None (TYPE_CHECKING imports)
- Export WanRollingKVCache and WanRollingKVBlockCache from diffusers top-level
- Rename _tok to _arange_tokens in tests
@github-actions github-actions Bot added models tests size/L PR with diff > 200 LOC labels May 5, 2026
gueraf and others added 3 commits May 5, 2026 14:42
- Drop `import unittest`
- Replace `(unittest.TestCase)` bases with plain classes
- Replace `self.assert*` calls with bare `assert` and `pytest.raises`
- Rename `setUp` → `setup_method` (pytest convention for plain classes)
- Use `torch.equal` instead of `not torch.allclose` for frame-offset
  differing-output test: the tiny model produces a real but sub-1e-5
  difference that falls within allclose's default rtol

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop TestWanRollingKVBlockCache/TestWanRollingKVCache (initial-state
  and configure_write validation tests — not interesting behavior)
- Drop TestCrossAttentionCache, TestFrameOffset as standalone classes
- Collapse AppendMode/WindowSize/OverwriteMode/FrameOffset into one
  TestWanRollingKVCache with _chunk/_run/_len as instance methods
- Move _arange_tokens inline as TestTrimToWindow._tok and
  TestSliceForOverwrite._bc (keeps helpers local to their users)
- Shrink module-level config to a single compact dict

35 tests → 12 tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move _TINY_CONFIG, _NUM_BLOCKS, _TOKENS_PER_CHUNK and all three test
classes (TestTrimToWindow, TestSliceForOverwrite, TestWanRollingKVCache)
into one TestWanRollingKVCache. Config and constants are class attributes
_CONFIG/_N/_T; trim/slice helpers share _tok with the forward tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- self.t → self.transformer
- _N/_T → NUM_BLOCKS/TOKENS_PER_CHUNK (class constants, not cryptic)
- _CONFIG expanded to one key per line; num_layers=NUM_BLOCKS makes
  the relationship explicit
- _run(*args) → explicit (latents, timestep, encoder_hidden_states)
  params; cache is now required (every caller always passes one);
  dead None-branch removed
- _len(cache=None) → _cached_len(cache) — no implicit self.cache
  default; all callers name the cache explicitly
- _tok → _filled_block_cache (describes what it actually builds)
- test_overwrite: local T = self.TOKENS_PER_CHUNK to avoid repetition

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ce unit tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…eady -1)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…helpers

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sert_unchanged

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ll sites

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…servation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…nk-loop test

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ffsets)

The old configure_write(absolute_token_offset=N) silently truncated the cache
to N+chunk_size, making mid-sequence overwrites a footgun. Restrict the API
to two well-defined modes:
- append: extend cache by chunk_size (transition to a new chunk)
- overwrite_end: drop last chunk_size tokens before appending (replace last
  chunk in place — used for subsequent denoising steps within the same chunk)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
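The two write modes can be illustrated with a list-based toy (illustrative names; the real cache stores tensors per transformer block):

```python
def toy_write(cache_tokens, chunk, mode, chunk_size):
    """Sketch of the two write modes: 'append' extends the cache by one
    chunk; 'overwrite_end' drops the last chunk_size tokens first, i.e.
    replaces the last chunk in place (re-denoising the current chunk)."""
    if mode == "overwrite_end":
        keep = max(len(cache_tokens) - chunk_size, 0)
        cache_tokens = cache_tokens[:keep]  # drop the last chunk
    elif mode != "append":
        raise ValueError(f"unknown mode: {mode}")
    return cache_tokens + chunk


tokens = []
tokens = toy_write(tokens, ["a0", "a1"], "append", chunk_size=2)           # chunk A, step 0
tokens = toy_write(tokens, ["a0'", "a1'"], "overwrite_end", chunk_size=2)  # chunk A, step 1
tokens = toy_write(tokens, ["b0", "b1"], "append", chunk_size=2)           # chunk B
print(tokens)  # ["a0'", "a1'", "b0", "b1"]
```

Because `overwrite_end` can only touch the final chunk, the silent mid-sequence truncation of the old `absolute_token_offset` API is impossible by construction.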
…rop-last helper

- setup_method only creates the transformer; each test creates its own cache
  with explicit window_size, making the (mode, window_size) configuration
  obvious from the test body
- Add test_append_windowed_three_chunks to exercise the rolling case where the
  window holds multiple chunks and surviving chunks shift left on eviction
- Simplify _wan_rolling_kv_drop_last_chunk to always slice (returning empty
  tensors when keep=0 instead of None), matching the previous overwrite path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>