Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing"). #13681
Draft
gueraf wants to merge 19 commits into huggingface:main
Adds inline rolling KV cache classes to transformer_wan.py, enabling
efficient autoregressive chunk-wise video generation without recomputing
previous chunks' attention context.
New public API (importable from diffusers.models.transformers.transformer_wan):
- WanRollingKVBlockCache: per-block K/V storage with optional cross-attn cache
- WanRollingKVCache: container managing all blocks, write modes, and window trimming
Usage:

    cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=8000)
    transformer(..., attention_kwargs={"rolling_kv_cache": cache})
Key features:
- "append" mode: grows the temporal prefix each chunk
- "overwrite" mode: writes clean K/V at a specific absolute token offset,
allowing arbitrary temporal placement (e.g. injecting ground-truth frames)
- window_size: trims oldest tokens to keep memory bounded
- cache_cross_attention: reuses text encoder projections across chunks
- frame_offset parameter on WanRotaryPosEmbed and WanTransformer3DModel.forward
for correct temporal RoPE positioning during chunk-wise generation
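To make the window-trimming behavior concrete, here is a hypothetical pure-Python sketch (not the PR's implementation) of the arithmetic behind "append" mode plus window_size, using plain lists to stand in for the temporal token axis of the cached K/V tensors. The function names trim_to_window and append_chunk are illustrative only.

```python
def trim_to_window(cached_tokens, window_size):
    """Keep only the most recent `window_size` tokens; evict the oldest."""
    if window_size is None or len(cached_tokens) <= window_size:
        return cached_tokens
    return cached_tokens[-window_size:]


def append_chunk(cache, chunk, window_size):
    """'append' mode: grow the temporal prefix by one chunk, then bound it."""
    return trim_to_window(cache + chunk, window_size)


# Three chunks of 4 tokens with an 8-token window: the first chunk is evicted.
cache = []
for i in range(3):
    cache = append_chunk(cache, list(range(i * 4, (i + 1) * 4)), window_size=8)
# cache now holds tokens 4..11
```

Note that after eviction the surviving tokens conceptually keep their absolute temporal positions, which is why the forward pass also takes a frame_offset so RoPE stays aligned with the global timeline.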
Also adds examples/inference/autoregressive_video_generation.py demonstrating
chunk-wise generation with the Self-Forcing transformer.
Tests cover WanRollingKVBlockCache and WanRollingKVCache state, both helper functions, append mode (incremental chunk assertions), window trimming, overwrite mode, cross-attention caching, and frame offset behavior.
Split WanAttnProcessor.__call__ into _wan_self_attention and _wan_cross_attention module-level functions; __call__ is now a 3-line router. Also hoists apply_rotary_emb to module level. Moves rolling KV cache tests into test_models_transformer_wan.py (deleted test_wan_rolling_kv_cache.py). Tests now use arange-based deterministic inputs and assert on explicit Python lists.
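The router refactor described above can be sketched as follows. This is a hypothetical toy (the real _wan_self_attention / _wan_cross_attention in transformer_wan.py operate on attention tensors, not strings); it only illustrates the design choice of a thin __call__ dispatching to two module-level functions.

```python
def _self_attention(query, context):
    # Stand-in for _wan_self_attention: context is unused here.
    return f"self({query})"


def _cross_attention(query, context):
    # Stand-in for _wan_cross_attention.
    return f"cross({query},{context})"


class AttnProcessor:
    """Thin router: dispatch on whether cross-attention context is given."""

    def __call__(self, query, context=None):
        fn = _self_attention if context is None else _cross_attention
        return fn(query, context)
```

Keeping the two paths as module-level functions keeps __call__ trivial and lets the cached cross-attention path be tested and modified independently of self-attention.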
…rename _tok
- _get_kv_projections: new helper that projects only K and V; used in the cross-attention cached path instead of discarding Q from _get_qkv_projections
- _wan_self_attention / _wan_cross_attention: annotate backend and parallel_config as AttentionBackendName | None and ParallelConfig | None (TYPE_CHECKING imports)
- Export WanRollingKVCache and WanRollingKVBlockCache from diffusers top-level
- Rename _tok to _arange_tokens in tests
- Drop `import unittest`
- Replace `(unittest.TestCase)` bases with plain classes
- Replace `self.assert*` calls with bare `assert` and `pytest.raises`
- Rename `setUp` → `setup_method` (pytest convention for plain classes)
- Use `torch.equal` instead of `not torch.allclose` for the frame-offset differing-output test: the tiny model produces a real but sub-1e-5 difference that falls within allclose's default rtol

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop TestWanRollingKVBlockCache/TestWanRollingKVCache (initial-state and configure_write validation tests — not interesting behavior)
- Drop TestCrossAttentionCache, TestFrameOffset as standalone classes
- Collapse AppendMode/WindowSize/OverwriteMode/FrameOffset into one TestWanRollingKVCache with _chunk/_run/_len as instance methods
- Move _arange_tokens inline as TestTrimToWindow._tok and TestSliceForOverwrite._bc (keeps helpers local to their users)
- Shrink module-level config to a single compact dict

35 tests → 12 tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move _TINY_CONFIG, _NUM_BLOCKS, _TOKENS_PER_CHUNK and all three test classes (TestTrimToWindow, TestSliceForOverwrite, TestWanRollingKVCache) into one TestWanRollingKVCache. Config and constants are class attributes _CONFIG/_N/_T; trim/slice helpers share _tok with the forward tests. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- self.t → self.transformer
- _N/_T → NUM_BLOCKS/TOKENS_PER_CHUNK (class constants, not cryptic)
- _CONFIG expanded to one key per line; num_layers=NUM_BLOCKS makes the relationship explicit
- _run(*args) → explicit (latents, timestep, encoder_hidden_states) params; cache is now required (every caller always passes one); dead None-branch removed
- _len(cache=None) → _cached_len(cache) — no implicit self.cache default; all callers name the cache explicitly
- _tok → _filled_block_cache (describes what it actually builds)
- test_overwrite: local T = self.TOKENS_PER_CHUNK to avoid repetition

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ffsets)
The old configure_write(absolute_token_offset=N) silently truncated the cache to N+chunk_size, making mid-sequence overwrites a footgun. Restrict the API to two well-defined modes:
- append: extend cache by chunk_size (transition to a new chunk)
- overwrite_end: drop last chunk_size tokens before appending (replace last chunk in place — used for subsequent denoising steps within the same chunk)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
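The two write modes above can be sketched with a hypothetical list-based model (the PR operates on per-block K/V tensors instead; write_chunk is an illustrative name, not the PR's API):

```python
def write_chunk(cache, chunk, mode):
    """Apply one cache write in either of the two supported modes."""
    if mode == "append":
        # Transition to a new chunk: extend the temporal prefix.
        return cache + chunk
    if mode == "overwrite_end":
        # Re-denoise the current chunk: drop the last chunk_size tokens,
        # then write the clean chunk in their place.
        keep = max(len(cache) - len(chunk), 0)
        return cache[:keep] + chunk
    raise ValueError(f"unknown mode: {mode}")
```

Restricting writes to these two shapes removes the arbitrary-offset case that could silently truncate the middle of the cache.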
…rop-last helper - setup_method only creates the transformer; each test creates its own cache with explicit window_size, making the (mode, window_size) configuration obvious from the test body - Add test_append_windowed_three_chunks to exercise the rolling case where the window holds multiple chunks and surviving chunks shift left on eviction - Simplify _wan_rolling_kv_drop_last_chunk to always slice (returning empty tensors when keep=0 instead of None), matching the previous overwrite path Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
What does this PR do?
Motivation
This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.
As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models without having to ship many custom modules, ideally none.
Progresses #12600
Before submitting
See the documentation guidelines and the tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.