fix(deepfloyd_if): remove @torch.no_grad() from encode_prompt in all IF variants#13688
Open
xodn348 wants to merge 1 commit intohuggingface:mainfrom
Open
fix(deepfloyd_if): remove @torch.no_grad() from encode_prompt in all IF variants#13688xodn348 wants to merge 1 commit intohuggingface:mainfrom
xodn348 wants to merge 1 commit intohuggingface:mainfrom
Conversation
…IF pipeline variants Removes the method-level @torch.no_grad() decorator from encode_prompt() in all six DeepFloyd IF pipeline files. The helper-level decorator prevented callers who invoke encode_prompt() directly (e.g. prompt-embedding optimisation loops, training-style workflows) from receiving gradient-tracked tensors even when they passed requires_grad=True inputs. Inference behaviour is unchanged: __call__() retains its own @torch.no_grad() context, so encode_prompt() continues to run without tracking gradients on the standard inference path. Refs huggingface#13646
yiyixuxu
reviewed
May 7, 2026
| ) | ||
| self.register_to_config(requires_safety_checker=requires_safety_checker) | ||
|
|
||
| @torch.no_grad() |
Collaborator
There was a problem hiding this comment.
Let's do a short deprecation cycle? Otherwise, people currently are using encode_prompt without torch.no_grad may not even notice this
someth8ing like
if torch.is_grad_enabled():
deprecate(...)
with torch.no_grad():
...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
All six DeepFloyd IF pipeline variants decorate their
encode_prompt()helper with@torch.no_grad(). This method-level context manager silently stripsrequires_gradfrom every tensor that passes through the function, including user-suppliedprompt_embedsthat were explicitly constructed with gradient tracking. As a result, advanced callers — prompt-embedding optimisation loops, textual-inversion fine-tuning, differentiable generation pipelines — cannot useencode_prompt()as a reusable building block because it discards the gradient graph regardless of the surrounding context.The fix is to remove
@torch.no_grad()fromencode_prompt()across all six files (pipeline_if.py,pipeline_if_img2img.py,pipeline_if_superresolution.py,pipeline_if_img2img_superresolution.py,pipeline_if_inpainting.py,pipeline_if_inpainting_superresolution.py). The standard inference path is unaffected:__call__()retains its own@torch.no_grad()decorator, which wrapsencode_prompt()during all normal inference calls. The Flux and other modern pipelines already follow this convention (e.g.src/diffusers/pipelines/flux/pipeline_flux.pydoes not annotateencode_prompt()with@torch.no_grad()).Issue
Refs #13646 (Issue 5 —
encode_prompt()detaches gradients in all IF pipelines)Local verification
Risk
The only observable change is that
encode_prompt()no longer unconditionally detaches the computation graph. Users who call it inside an activetorch.no_grad()block (the default inference path via__call__()) see no difference. Users who call it outsidetorch.no_grad()and pass gradient-trackedprompt_embedswill now receive a tensor that participates in autograd, which is the intended behaviour. There is a negligible memory overhead for callers who do not need gradients but callencode_prompt()outside anytorch.no_grad()context — they can wrap the call themselves if needed.