
Cache ModelMixin.dtype to avoid named_parameters walk per access#13571

Open
akshan-main wants to merge 1 commit into huggingface:main from akshan-main:cache-modelmixin-dtype
Conversation

@akshan-main
Contributor

@akshan-main akshan-main commented Apr 28, 2026

What does this PR do?

Addresses #13401

ModelMixin.dtype calls get_parameter_dtype(), which walks named_parameters() on every access. Pipelines call self.transformer.dtype / self.text_encoder.dtype / self.vae.dtype inside their denoise loops, so the walk fires on every step.

This PR caches the dtype on first access and invalidates via _apply (which .to(), .cpu(), .cuda(), .half(), .bfloat16() etc. all flow through). One small change benefits every pipeline that subclasses ModelMixin.

device is intentionally not cached: with group offloading, the effective device changes per-forward as groups onload/offload. Caching it would break that flow.

This is the same shape of fix as the centralized cache_context._set_context cache in #13356.
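The caching idea can be sketched as a small mixin. This is an illustrative shape only, not the actual diffusers implementation; attribute names like `_dtype_cache` and the `CachedDtypeMixin` class are hypothetical:

```python
import torch
import torch.nn as nn

class CachedDtypeMixin:
    """Sketch of the PR's idea: cache dtype on first access, invalidate
    in _apply. Names here are illustrative, not diffusers internals."""

    _dtype_cache = None

    @property
    def dtype(self) -> torch.dtype:
        # First access walks the parameters (what a get_parameter_dtype-style
        # helper does); later accesses return the cached value.
        if self._dtype_cache is None:
            self._dtype_cache = next(self.parameters()).dtype
        return self._dtype_cache

    def _apply(self, fn, *args, **kwargs):
        # .to(), .cpu(), .cuda(), .half(), .bfloat16() all funnel through
        # nn.Module._apply, so invalidating here covers every dtype change.
        self._dtype_cache = None
        return super()._apply(fn, *args, **kwargs)

class TinyModel(CachedDtypeMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

model = TinyModel()
print(model.dtype)  # torch.float32 (computed once, then cached)
model.half()        # flows through _apply, so the cache is cleared
print(model.dtype)  # torch.float16
```

Putting the invalidation in `_apply` rather than overriding each conversion method is what makes the change small: every dtype-changing entry point in `nn.Module` already funnels through it.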

  • Cache returns the same torch.dtype value get_parameter_dtype() would return; generation outputs are bit-identical.
  • .to() / .cpu() / .cuda() / .half() / .bfloat16() all flow through nn.Module._apply, so the cache is invalidated correctly when the actual dtype changes.
  • Microbench on AutoencoderKL: 87.81us → 0.09us per .dtype access (963x).
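The shape of that microbench can be reproduced with plain PyTorch. This is an illustrative harness, not the one used for the AutoencoderKL numbers above (the model and iteration counts here are arbitrary):

```python
import timeit
import torch.nn as nn

# Illustrative microbench shape only: compare a full parameter walk per
# .dtype access against a plain cached read. Absolute numbers will differ
# from the AutoencoderKL figures quoted in the PR.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(200)])

def dtype_walk():
    # roughly what a named_parameters()-style walk per access costs
    last = None
    for p in model.parameters():
        last = p.dtype
    return last

cached_dtype = dtype_walk()  # computed once, then reused

n = 2_000
walk_us = timeit.timeit(dtype_walk, number=n) / n * 1e6
hit_us = timeit.timeit(lambda: cached_dtype, number=n) / n * 1e6
print(f"walk: {walk_us:.2f} us/access, cached: {hit_us:.2f} us/access")
```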

Profiling - surveyed across 10 pipelines (eager, 2 inference steps, H100)

| pipeline | `get_parameter_dtype` calls (before / after) | `get_parameter_dtype` total ms (before / after) | inter-step gap ms (before / after) | pipeline call ms (before / after) |
| --- | --- | --- | --- | --- |
| flux2 | 2 / 0 | 2.20 / 0.00 | 1.34 / 0.28 | 191.76 / 185.81 |
| qwenimage (full decode) | 1 / 0 | 1.07 / 0.00 | 0.09 / 0.09 | 1430.86 / 1421.92 |
| qwenimage_edit (full decode) | 1 / 0 | 1.04 / 0.00 | 0.12 / 0.13 | 3771.11 / 3772.56 |
| z_image | 2 / 0 | 4.21 / 0.00 | 5.46 / 3.51 | 852.15 / 845.42 |
| chroma | 0 / 0 | 0.00 / 0.00 | 0.05 / 0.06 | 1501.74 / 1499.49 |
| sdxl (full decode) | 2 / 0 | 1.80 / 0.00 | 0.00 / 0.00 | 328.60 / 330.17 |
| sana | 1 / 0 | 1.47 / 0.00 | 21.44 / 21.52 | 132.70 / 130.65 |
| hunyuanv15 (video) | 5 / 0 | 30.95 / 0.00 | 0.09 / 0.08 | 12884.55 / 12862.06 |
| wan2.2 (video) | 1 / 0 | 2.79 / 0.00 | 0.06 / 0.06 | 1891.21 / 1866.68 |
| ltx2 (video) | 0 / 0 | 0.00 / 0.00 | 166.77 / 166.82 | 3682.21 / 3679.66 |

The fix removes the walk wherever it appears (biggest impact on hunyuanv15: 30.95ms at 2 inference steps, and the cost scales linearly with num_inference_steps). On pipelines where the walk doesn't appear (chroma, ltx2), the fix is a no-op and there is no regression.
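Taking the linear-scaling claim at face value, a back-of-envelope extrapolation from the hunyuanv15 row (this projection is mine, not a measured figure from the PR):

```python
# hunyuanv15: 30.95 ms of get_parameter_dtype time measured at 2 steps.
# Assuming the cost is per-step and linear in num_inference_steps:
saved_at_2_steps_ms = 30.95
per_step_ms = saved_at_2_steps_ms / 2       # ~15.5 ms per step
print(per_step_ms * 50)                     # ~774 ms saved at 50 steps
```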

Reproduction notebook (Colab) - applies the central fix, profiles every pipeline before and after, consolidated table at bottom of notebook.

Before submitting

  • This PR fixes a typo or improves the docs.
  • Did you read the contributor guideline?
  • Did you read our philosophy doc?
  • Was this discussed/approved via a GitHub issue or the forum?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@sayakpaul @dg845

@akshan-main
Contributor Author

akshan-main commented Apr 28, 2026

Profiled SD3 too (eager + compile, RTX PRO 6000 Blackwell, 2 steps) following the profiling guide.
Notebook: https://colab.research.google.com/gist/akshan-main/cb9ee83575806704e93e03496ba0d940/sd3_profiling.ipynb

Denoising loop is clean. 0 syncs in in_loop_body, in_transformer_forward, or in_scheduler_step after set_begin_index(0).

Pre-loop has 2x ~10ms aten::copy_ from scheduler.set_timesteps (numpy to GPU sigmas) and _get_clip_prompt_embeds (tokenizer ids to GPU). One 62ms aten::nonzero in the first _init_step_index call which set_begin_index(0) eliminates.
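The shape of that nonzero sync can be illustrated without the scheduler itself. This is a minimal sketch of the two lookup paths, not the actual scheduler code; the sigma schedule here is made up:

```python
import torch

# Finding the current step index by searching the timestep tensor
# (the _init_step_index-style path) materializes a nonzero() result and
# .item() forces a device sync. A preset begin index is a plain read.
timesteps = torch.linspace(1000, 0, 28)
t = timesteps[0]

# search path: nonzero() + .item() is where the sync happens on GPU
step_index = (timesteps == t).nonzero().item()

# set_begin_index(0)-style path: no tensor search at all
begin_index = 0
print(step_index, begin_index)  # 0 0
```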

Tested adding set_begin_index(0) (matches Flux/Wan/Flux2). Trace sync drops from 62ms to 0 but wall-clock is within noise:

| Mode | Before | After | Delta |
| --- | --- | --- | --- |
| Eager | 233.0 ± 0.6 ms | 231.9 ± 1.0 ms | -1.1 ms |
| Compile | 200.2 ± 0.3 ms | 199.9 ± 0.3 ms | -0.3 ms |

The sync was a queue drain: the GPU has to do that work anyway; the CPU just stops waiting for it. Unlike Z-Image (#13461), there is no per-step .item()/.cpu() to chase here, and the remaining pre-loop syncs are legitimate one-time copies. Not opening a PR for the SD3 profile.

@sayakpaul @dg845

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu April 30, 2026 06:35
@akshan-main
Contributor Author

akshan-main commented May 1, 2026

Friendly ping @DN6 @yiyixuxu, would love a review today if time permits!

@akshan-main
Contributor Author

Friendly ping @yiyixuxu @DN6, could this be reviewed when possible?


Labels

models size/S PR with diff < 50 LOC
