Labels: bug (Something isn't working)
Describe the bug
QwenImagePipeline fails when Ulysses sequence parallelism (SP) is combined with batched prompt inputs. The root cause is that the attention mask is not correctly broadcast before being passed to the attention backend.
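For context, here is a minimal standalone sketch of the same shape mismatch. This is not the diffusers code path, just an illustration of SDPA's mask-broadcasting rules with made-up shapes: `attn_mask` must be broadcastable to `[batch, heads, q_len, kv_len]`, and a 2-D `[batch, seq_len]` mask is not, because broadcasting aligns dimensions from the right.

```python
import torch
import torch.nn.functional as F

# Shapes mirroring the error (smaller seq len for speed):
# the target attention-mask shape is [batch, heads, q_len, kv_len].
B, H, L, D = 2, 12, 16, 8
q = k = v = torch.randn(B, H, L, D)

# A per-token mask of shape [B, L] is NOT broadcastable to
# [B, H, L, L]: broadcasting aligns dimensions from the right,
# so B lands on the q_len axis and expansion fails.
mask_2d = torch.ones(B, L, dtype=torch.bool)
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask_2d)
except RuntimeError as e:
    print("fails:", e)

# Unsqueezing to [B, 1, 1, L] makes the mask broadcastable.
mask_4d = mask_2d[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_4d)
print(out.shape)  # torch.Size([2, 12, 16, 8])
```

With batch size 1 the 2-D mask can happen to broadcast by accident, which would explain why the bug only surfaces with batched prompts.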
Reproduction
With the following code snippet:

```python
import torch
import torch.distributed as dist
import argparse
import os
from diffusers import QwenImagePipeline
from diffusers import ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test Qwen-Image with Context Parallelism")
    return parser.parse_args()


args = parse_args()

if dist.is_available():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = "Qwen/Qwen-Image"
pipe = QwenImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)
pipe.transformer.set_attention_backend("native")

if world_size > 1:
    from diffusers import QwenImageTransformer2DModel

    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(
        ulysses_degree=world_size))

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for English prompts
    "zh": ", 超清,4K,电影级构图.",  # for Chinese prompts
}

prompts = [
    "A coffee shop entrance features a chalkboard sign reading "
    '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
    'displaying "通义千问". Next to it hangs a poster showing a '
    "beautiful Chinese woman, and beneath the poster is written "
    '"π≈3.1415926-53589793-23846264-33832795-02384197". '
    "Ultra HD, 4K, cinematic composition",
    "A cute cat with long hair sitting on a sofa, Ultra HD, 4K, cinematic composition.",
]

inputs = {
    "prompt": [p + positive_magic["en"] for p in prompts],
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)

for i, output_image in enumerate(output.images):
    if world_size > 1:
        save_path = f"output_image_ulysses{world_size}_{i}.png"
    else:
        save_path = f"output_image_{i}.png"
    if rank == 0:
        output_image.save(save_path)
        print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()
```

running it with the following command reproduces the issue:

```shell
torchrun --nproc_per_node=2 --local-ranks-filter=0 test.py
```

Logs
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/test.py", line 71, in <module>
[rank0]: output = pipe(**inputs)
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 681, in __call__
[rank0]: noise_pred = self.transformer(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/utils/peft_utils.py", line 315, in wrapper
[rank0]: result = forward_fn(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 953, in forward
[rank0]: encoder_hidden_states, hidden_states = block(
[rank0]: ^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]: output = function_reference.forward(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]: output = function_reference.forward(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 700, in forward
[rank0]: attn_output = self.attn(
[rank0]: ^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_processor.py", line 607, in forward
[rank0]: return self.processor(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 548, in __call__
[rank0]: joint_hidden_states = dispatch_attention_fn(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 432, in dispatch_attention_fn
[rank0]: return backend_fn(**kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2928, in _native_attention
[rank0]: out = _templated_context_parallel_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2276, in _templated_context_parallel_attention
[rank0]: return TemplatedUlyssesAttention.apply(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni-016/lib/python3.12/site-packages/torch/autograd/function.py", line 581, in apply
[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 2005, in forward
[rank0]: out = forward_op(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/diffusers/src/diffusers/models/attention_dispatch.py", line 826, in _native_attention_forward_op
[rank0]: out = torch.nn.functional.scaled_dot_product_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: The expanded size of the tensor (4222) must match the existing size (2) at non-singleton dimension 2. Target sizes: [2, 12, 4222, 4222]. Tensor sizes: [2, 4222]

System Info
- 🤗 Diffusers version: 0.38.0.dev0
- Platform: Linux-5.15.0-1053-nvidia-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.9.1+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.2
- Transformers version: 4.57.6
- Accelerate version: 1.12.0
- PEFT version: 0.18.1
- Bitsandbytes version: not installed
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H800, 81559 MiB
  NVIDIA H800, 81559 MiB
  NVIDIA H800, 81559 MiB
  NVIDIA H800, 81559 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?
No response