
QwenImagePipeline failed with Ulysses SP and batch inputs #13277

@zhtmike

Description


Describe the bug

QwenImagePipeline fails when Ulysses SP is used together with batched prompt inputs. The failure appears to be caused by the attention mask not being broadcast correctly.
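As a minimal illustration of the shape mismatch in the error below (this is a standalone sketch of PyTorch's broadcasting rules, not the diffusers code path): `scaled_dot_product_attention` needs `attn_mask` to be expandable to `[batch, heads, query_len, key_len]`, but the mask arrives as a raw `[batch, seq_len]` tensor.

```python
# Sketch (no torch needed) of the expand check that
# scaled_dot_product_attention effectively applies to attn_mask.
# The shapes mirror the error message below.

def can_expand(src, target):
    """True if a tensor of shape `src` can be expanded to `target` under
    PyTorch broadcasting rules: dims are right-aligned, and each src dim
    must equal the target dim or be 1."""
    if len(src) > len(target):
        return False
    for s, t in zip(reversed(src), reversed(target)):
        if s != t and s != 1:
            return False
    return True

target = (2, 12, 4222, 4222)  # [batch, heads, query_len, key_len]

print(can_expand((2, 4222), target))        # raw [batch, seq] mask -> False
print(can_expand((2, 1, 1, 4222), target))  # mask as [B, 1, 1, L] -> True
print(can_expand((1, 4222), target))        # batch size 1 -> True
```

Note the last case: with a single prompt the `[1, 4222]` mask happens to broadcast because of the leading size-1 dim, which would explain why the failure only shows up with batched inputs.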

Reproduction

Save the following script as test.py:

import torch
import torch.distributed as dist
import argparse
import os
from diffusers import QwenImagePipeline
from diffusers import ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test Qwen-Image with Context Parallelism")
    return parser.parse_args()


args = parse_args()

if dist.is_available():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = "Qwen/Qwen-Image"

pipe = QwenImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(
        ulysses_degree=world_size))

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}
prompts = [
    "A coffee shop entrance features a chalkboard sign reading "
    '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
    'displaying "通义千问". Next to it hangs a poster showing a '
    "beautiful Chinese woman, and beneath the poster is written "
    '"π≈3.1415926-53589793-23846264-33832795-02384197". '
    "Ultra HD, 4K, cinematic composition",
    "A cute cat with long hair sitting on a sofa, Ultra HD, 4K, cinematic composition."
]

inputs = {
    "prompt": [p + positive_magic["en"] for p in prompts],
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)
    for i, output_image in enumerate(output.images):
        if world_size > 1:
            save_path = f"output_image_ulysses{world_size}_{i}.png"
        else:
            save_path = f"output_image_{i}.png"
        if rank == 0:
            output_image.save(save_path)
            print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

Launching the script with the following command reproduces the issue:

torchrun --nproc_per_node=2 --local-ranks-filter=0 test.py

Logs

[rank0]: Traceback (most recent call last):
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/test.py", line 71, in <module>
[rank0]:     output = pipe(**inputs)
[rank0]:              ^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 681, in __call__
[rank0]:     noise_pred = self.transformer(
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/utils/peft_utils.py", line 315, in wrapper
[rank0]:     result = forward_fn(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 953, in forward
[rank0]:     encoder_hidden_states, hidden_states = block(
[rank0]:                                            ^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 700, in forward
[rank0]:     attn_output = self.attn(
[rank0]:                   ^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_processor.py", line 607, in forward
[rank0]:     return self.processor(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 548, in __call__
[rank0]:     joint_hidden_states = dispatch_attention_fn(
[rank0]:                           ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 432, in dispatch_attention_fn
[rank0]:     return backend_fn(**kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2928, in _native_attention
[rank0]:     out = _templated_context_parallel_attention(
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2276, in _templated_context_parallel_attention
[rank0]:     return TemplatedUlyssesAttention.apply(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/autograd/function.py", line 581, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2005, in forward
[rank0]:     out = forward_op(
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 826, in _native_attention_forward_op
[rank0]:     out = torch.nn.functional.scaled_dot_product_attention(
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: The expanded size of the tensor (4222) must match the existing size (2) at non-singleton dimension 2.  Target sizes: [2, 12, 4222, 4222].  Tensor sizes: [2, 4222]

System Info

  • 🤗 Diffusers version: 0.38.0.dev0
  • Platform: Linux-5.15.0-1053-nvidia-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.9.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.2
  • Transformers version: 4.57.6
  • Accelerate version: 1.12.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H800, 81559 MiB
    NVIDIA H800, 81559 MiB
    NVIDIA H800, 81559 MiB
    NVIDIA H800, 81559 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

No response

Labels: bug (Something isn't working)