
fix(deepfloyd_if): remove @torch.no_grad() from encode_prompt in all IF variants #13688

Open

xodn348 wants to merge 1 commit into huggingface:main from xodn348:fix/deepfloyd-encode-prompt-no-grad

Conversation


xodn348 commented May 7, 2026

Summary

All six DeepFloyd IF pipeline variants decorate their encode_prompt() helper with @torch.no_grad(). This method-level decorator silently disables gradient tracking for every tensor the function produces, including the copies it makes of user-supplied prompt_embeds that were explicitly constructed with requires_grad=True. As a result, advanced callers (prompt-embedding optimisation loops, textual-inversion fine-tuning, differentiable generation pipelines) cannot use encode_prompt() as a reusable building block, because it discards the gradient graph regardless of the surrounding context.
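To see the failure mode in isolation, here is a minimal sketch (the helper names are made up for illustration, not taken from the codebase):

import torch

@torch.no_grad()
def encode(x):
    # Every op inside runs with autograd disabled, so the output
    # is detached no matter what the caller set up.
    return x * 2.0

def encode_fixed(x):
    return x * 2.0

x = torch.ones(3, requires_grad=True)
print(encode(x).requires_grad)        # False: the graph back to x was never built
print(encode_fixed(x).requires_grad)  # True: gradients can flow back to x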

The fix removes @torch.no_grad() from encode_prompt() across all six files (pipeline_if.py, pipeline_if_img2img.py, pipeline_if_superresolution.py, pipeline_if_img2img_superresolution.py, pipeline_if_inpainting.py, pipeline_if_inpainting_superresolution.py). The standard inference path is unaffected: __call__() retains its own @torch.no_grad() decorator, which wraps encode_prompt() during all normal inference calls. Flux and other modern pipelines already follow this convention (e.g. src/diffusers/pipelines/flux/pipeline_flux.py does not annotate encode_prompt() with @torch.no_grad()).
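Once the decorator is gone, encode_prompt() composes with autograd. A minimal sketch of the prompt-embedding optimisation this unblocks (the objective, learning rate, and step count are placeholders; make_unet() and sched are the tiny test fixtures defined in the verification script below):

pipe = IFPipeline(None, None, make_unet(), sched, None, None, None, False)
embeds = torch.randn(1, 77, 4, requires_grad=True)  # 4 = the tiny UNet's cross_attention_dim
target = torch.zeros(1, 77, 4)                      # placeholder optimisation target
opt = torch.optim.Adam([embeds], lr=1e-2)

for _ in range(10):
    out, _ = pipe.encode_prompt(prompt=None, do_classifier_free_guidance=False,
                                prompt_embeds=embeds, num_images_per_prompt=1)
    loss = torch.nn.functional.mse_loss(out, target)  # placeholder objective
    opt.zero_grad()
    loss.backward()  # only works once encode_prompt() stops detaching the graph
    opt.step()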

Issue

Refs #13646 (Issue 5 — encode_prompt() detaches gradients in all IF pipelines)

Local verification

$ python3 - << 'EOF'
import sys; sys.path.insert(0, 'src')
import torch, subprocess
from diffusers import DDPMScheduler, UNet2DConditionModel, IFPipeline
from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img import IFImg2ImgPipeline
from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution import IFSuperResolutionPipeline

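# Tiny UNet config: just enough for the pipelines to instantiate
# (and expose a dtype/device) without loading real IF weights.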
def make_unet():
    return UNet2DConditionModel(
        sample_size=8, in_channels=3, out_channels=6, layers_per_block=1,
        block_out_channels=(8,), down_block_types=("CrossAttnDownBlock2D",),
        up_block_types=("CrossAttnUpBlock2D",), cross_attention_dim=4,
        attention_head_dim=4, norm_num_groups=1,
    )
sched = DDPMScheduler(num_train_timesteps=10, variance_type="learned_range")

# Test 1 - IFPipeline preserves grad
pipe1 = IFPipeline(None, None, make_unet(), sched, None, None, None, False)
x = torch.randn(1, 77, 4, requires_grad=True)
out, _ = pipe1.encode_prompt(prompt=None, do_classifier_free_guidance=False,
                              prompt_embeds=x, num_images_per_prompt=1)
assert out.requires_grad, "FAIL: IFPipeline strips grad"
print("IFPipeline: OK")

# Test 2 - IFImg2ImgPipeline preserves grad
pipe2 = IFImg2ImgPipeline(None, None, make_unet(), sched, None, None, None, False)
x2 = torch.randn(1, 77, 4, requires_grad=True)
out2, _ = pipe2.encode_prompt(prompt=None, do_classifier_free_guidance=False,
                               prompt_embeds=x2, num_images_per_prompt=1)
assert out2.requires_grad, "FAIL: IFImg2ImgPipeline strips grad"
print("IFImg2ImgPipeline: OK")

# Test 3 - IFSuperResolutionPipeline preserves grad
pipe3 = IFSuperResolutionPipeline(None, None, make_unet(), sched, sched, None, None, None, False)
x3 = torch.randn(1, 77, 4, requires_grad=True)
out3, _ = pipe3.encode_prompt(prompt=None, do_classifier_free_guidance=False,
                               prompt_embeds=x3, num_images_per_prompt=1)
assert out3.requires_grad, "FAIL: IFSuperResolutionPipeline strips grad"
print("IFSuperResolutionPipeline: OK")

# Test 4 - inference path (within explicit no_grad) still detaches
with torch.no_grad():
    y = torch.randn(1, 77, 4, requires_grad=True)
    out4, _ = pipe1.encode_prompt(prompt=None, do_classifier_free_guidance=False,
                                   prompt_embeds=y, num_images_per_prompt=1)
    assert not out4.requires_grad, "FAIL: no_grad context ignored"
print("inference-no_grad-context: OK")

# Test 5 - ruff
proc = subprocess.run(
    ["python3", "-m", "ruff", "check", "src/diffusers/pipelines/deepfloyd_if/"],
    capture_output=True, text=True
)
assert proc.returncode == 0, f"ruff: {proc.stdout}"
print("ruff: OK")

print("=== LOCAL_TEST_PASSED ===")
EOF
IFPipeline: OK
IFImg2ImgPipeline: OK
IFSuperResolutionPipeline: OK
inference-no_grad-context: OK
ruff: OK
=== LOCAL_TEST_PASSED ===

Risk

The only observable change is that encode_prompt() no longer unconditionally detaches the computation graph. Users who call it inside an active torch.no_grad() block (the default inference path via __call__()) see no difference. Users who call it outside torch.no_grad() and pass gradient-tracked prompt_embeds will now receive a tensor that participates in autograd, which is the intended behaviour. There is a negligible memory overhead for callers who do not need gradients but call encode_prompt() outside any torch.no_grad() context — they can wrap the call themselves if needed.
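For such callers, restoring the old behaviour is a one-line wrap (a sketch; pipe stands for any of the six IF pipelines):

with torch.no_grad():
    prompt_embeds, negative_embeds = pipe.encode_prompt(
        prompt="a photo of an astronaut",
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
    )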

fix(deepfloyd_if): remove @torch.no_grad() from encode_prompt in all IF pipeline variants

Removes the method-level @torch.no_grad() decorator from encode_prompt() in
all six DeepFloyd IF pipeline files. The helper-level decorator prevented
callers who invoke encode_prompt() directly (e.g. prompt-embedding
optimisation loops, training-style workflows) from receiving gradient-tracked
tensors even when they passed requires_grad=True inputs.

Inference behaviour is unchanged: __call__() retains its own @torch.no_grad()
context, so encode_prompt() continues to run without tracking gradients on the
standard inference path.

Refs huggingface#13646
github-actions bot added the size/S (PR with diff < 50 LOC) and pipelines labels on May 7, 2026

yiyixuxu (Collaborator) left a comment


thanks for the PR

)
self.register_to_config(requires_safety_checker=requires_safety_checker)

@torch.no_grad()

Let's do a short deprecation cycle? Otherwise, people who currently use encode_prompt without torch.no_grad may not even notice this change.

something like

if torch.is_grad_enabled():
    deprecate(...)
with torch.no_grad():
    ...
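For concreteness, a sketch of what that shim could look like (the deprecation name, removal version, and message are placeholders; _encode_prompt_impl is a hypothetical helper holding the current method body):

import torch
from diffusers.utils import deprecate

def encode_prompt(self, *args, **kwargs):
    if torch.is_grad_enabled():
        deprecate(
            "encode_prompt-forced-no_grad",  # placeholder deprecation name
            "0.40.0",                        # placeholder removal version
            "encode_prompt() currently forces torch.no_grad(); a future release "
            "will preserve gradients. Wrap the call in torch.no_grad() if you "
            "do not need them.",
        )
    # Old behaviour kept for the duration of the deprecation cycle.
    with torch.no_grad():
        return self._encode_prompt_impl(*args, **kwargs)  # hypothetical: current body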


Labels

pipelines, size/S (PR with diff < 50 LOC)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants