Cache `ModelMixin.dtype` to avoid `named_parameters` walk per access #13571
akshan-main wants to merge 1 commit into huggingface:main
Conversation
Profiled SD3 too (eager + compile, RTX PRO 6000 Blackwell, 2 steps) following the profiling guide. The denoising loop is clean: 0 syncs in the loop; the pre-loop has 2x ~10ms syncs. Tested adding
The sync was queue-drain: the GPU has to do that work anyway, the CPU just doesn't wait for it. Unlike Z-Image #13461, there is no per-step sync.
What does this PR do?
Addresses #13401
`ModelMixin.dtype` calls `get_parameter_dtype()`, which walks `named_parameters()` on every access. Pipelines call `self.transformer.dtype` / `self.text_encoder.dtype` / `self.vae.dtype` inside their denoise loops, so the walk fires every step. This PR caches the dtype on first access and invalidates the cache via `_apply` (which `.to()`, `.cpu()`, `.cuda()`, `.half()`, `.bfloat16()`, etc. all flow through). One small change benefits every pipeline that subclasses `ModelMixin`.

`.device` is intentionally not cached: with group offloading, the effective device changes per-forward as groups onload/offload, and caching it would break that flow.

Same shape of fix as the centralized `cache_context._set_context` cache in #13356. The cached value is exactly the `torch.dtype` that `get_parameter_dtype()` would return; generation outputs are bit-identical. `.to()` / `.cpu()` / `.cuda()` / `.half()` / `.bfloat16()` all flow through `nn.Module._apply`, so the cache is invalidated correctly when the actual dtype changes. `AutoencoderKL`: 87.81us → 0.09us per `.dtype` access (963x).

Profiling - surveyed across 10 pipelines (eager, 2 inference steps, H100)
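A minimal sketch of the cache-and-invalidate pattern described above (class and attribute names here are illustrative, not the PR's exact code): `dtype` walks the parameters only on a cache miss, and overriding `_apply` clears the cache on every dtype-changing call.

```python
import torch
import torch.nn as nn


def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # Simplified stand-in for the per-access walk this PR avoids:
    # iterate the parameters and report their dtype.
    dtype = torch.float32
    for param in module.parameters():
        dtype = param.dtype
    return dtype


class CachedDtypeMixin(nn.Module):
    """Illustrative mixin: cache dtype, invalidate via _apply."""

    _dtype_cache = None

    @property
    def dtype(self) -> torch.dtype:
        if self._dtype_cache is None:
            self._dtype_cache = get_parameter_dtype(self)  # walk once
        return self._dtype_cache

    def _apply(self, fn, *args, **kwargs):
        # .to(), .cpu(), .cuda(), .half(), .bfloat16() all route through
        # nn.Module._apply, so clearing the cache here covers every path
        # that can change the parameter dtype.
        self._dtype_cache = None
        return super()._apply(fn, *args, **kwargs)


class TinyModel(CachedDtypeMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


model = TinyModel()
print(model.dtype)  # first access walks the parameters once
model.half()        # flows through _apply -> cache invalidated
print(model.dtype)  # next access re-walks and sees the new dtype
```

Because the invalidation sits in `_apply` rather than in each public method, any future dtype-changing entry point that routes through `nn.Module._apply` is covered automatically.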
The fix removes the walk wherever it appears (most impact on hunyuanv15: 30.95ms at 2 inference steps; scales linearly with `num_inference_steps`). On pipelines where the walk doesn't appear (chroma, ltx2), there is no regression; the fix is a no-op there.

Reproduction notebook (Colab) - applies the central fix, profiles every pipeline before and after; consolidated table at the bottom of the notebook.
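The shape of the speedup (walk vs. cache hit) can be reproduced on CPU with a toy module; this micro-benchmark sketch uses a simplified `get_parameter_dtype` and a deep `nn.Sequential` as stand-ins, so absolute numbers will differ from the H100/AutoencoderKL figures quoted above.

```python
import timeit

import torch
import torch.nn as nn


def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # Simplified version of the walk: visit every named parameter.
    dtype = torch.float32
    for _name, param in module.named_parameters():
        dtype = param.dtype
    return dtype


# The walk cost scales with parameter count, so use a deep module.
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(500)])

n = 2_000
walk_us = timeit.timeit(lambda: get_parameter_dtype(model), number=n) / n * 1e6

cached = get_parameter_dtype(model)  # what the cache holds after first access
hit_us = timeit.timeit(lambda: cached, number=n) / n * 1e6

print(f"walk:      {walk_us:8.2f} us/access")
print(f"cache hit: {hit_us:8.2f} us/access")
```

The gap widens with model depth, which matches the report that the largest models (e.g. hunyuanv15) benefit most.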
Before submitting
Who can review?
@sayakpaul @dg845