Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
A script to convert DiffSynth-Studio Blockwise ControlNet checkpoints to the Diffusers format.

The DiffSynth checkpoints only contain the ControlNet-specific weights (controlnet_blocks + img_in).
The transformer backbone weights are loaded from the base Qwen-Image model.

Example:
Convert using HuggingFace repo:
```bash
python scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py \
--original_state_dict_repo_id "DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny" \
--filename "model.safetensors" \
--transformer_repo_id "Qwen/Qwen-Image" \
--output_path "output/qwenimage-blockwise-controlnet-canny" \
--dtype "bf16"
```

Or convert from a local file:
```bash
python scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py \
--checkpoint_path "path/to/model.safetensors" \
--transformer_repo_id "Qwen/Qwen-Image" \
--output_path "output/qwenimage-blockwise-controlnet-canny" \
--dtype "bf16"
```

Note:
Available DiffSynth blockwise ControlNet checkpoints:
- DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny
- DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth
- DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import QwenImageControlNetModel, QwenImageTransformer2DModel


# ---------------------------------------------------------------------------
# CLI definition.
#
# The parser is built at import time, but argument *parsing* is deferred to
# script execution: the original code called `parser.parse_args()` at module
# level, which raises SystemExit on the missing required --output_path as soon
# as the module is imported (e.g. to reuse load_original_checkpoint).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID for the blockwise checkpoint"
)
parser.add_argument("--filename", type=str, default="model.safetensors", help="Filename in the HF repo")
parser.add_argument(
    "--transformer_repo_id",
    type=str,
    default="Qwen/Qwen-Image",
    help="HuggingFace repo ID for the base transformer model",
)
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="bf16", help="Data type for the converted model (fp16, bf16, or fp32)"
)

if __name__ == "__main__":
    # Parse only when run as a script; `args` is then visible to the
    # `main(args)` call in the entry-point guard at the bottom of the file.
    args = parser.parse_args()


def load_original_checkpoint(args):
    """Resolve the checkpoint location and load its tensors.

    Prefers the Hub repo id over a local path when both are given (matching
    the original precedence).

    Args:
        args: parsed CLI namespace with ``original_state_dict_repo_id``,
            ``filename`` and ``checkpoint_path`` attributes.

    Returns:
        dict: tensor name -> tensor, as loaded by safetensors.

    Raises:
        ValueError: if neither a repo id nor a local checkpoint path is set.
    """
    repo_id = args.original_state_dict_repo_id
    if repo_id is not None:
        print(f"Downloading checkpoint from {repo_id}/{args.filename}")
        resolved_path = hf_hub_download(repo_id=repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        resolved_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(resolved_path)


def main(args):
    """Convert a DiffSynth blockwise ControlNet checkpoint to Diffusers format.

    Loads the base Qwen-Image transformer, builds a blockwise ControlNet from
    it (copying the backbone weights), overlays the DiffSynth-specific weights
    (``controlnet_blocks`` + ``img_in``), and saves the result with
    ``save_pretrained``.

    Args:
        args: parsed CLI namespace (see the module-level parser).

    Raises:
        ValueError: if ``args.dtype`` is not one of ``fp16``/``bf16``/``fp32``.
    """
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    if args.dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")
    dtype = dtype_map[args.dtype]

    # Load the base transformer backbone.
    print(f"Loading base transformer from {args.transformer_repo_id}...")
    base_transformer = QwenImageTransformer2DModel.from_pretrained(
        args.transformer_repo_id, subfolder="transformer", torch_dtype=dtype
    )

    # Create the controlnet from the transformer (copies backbone weights).
    print("Creating blockwise ControlNet from transformer...")
    controlnet = QwenImageControlNetModel.from_transformer(
        base_transformer,
        num_layers=base_transformer.config.num_layers,
        attention_head_dim=base_transformer.config.attention_head_dim,
        num_attention_heads=base_transformer.config.num_attention_heads,
        controlnet_block_type="blockwise",
    )

    # Overlay the DiffSynth blockwise weights (controlnet_blocks + img_in only);
    # strict=False because the backbone keys were already copied above.
    state_dict = load_original_checkpoint(args)
    missing_keys, unexpected_keys = controlnet.load_state_dict(state_dict, strict=False)

    # Sanity check: missing keys are expected (backbone), unexpected are not.
    print(f"Missing keys (expected - backbone from transformer): {len(missing_keys)}")
    print(f"Unexpected keys (should be 0): {len(unexpected_keys)}")
    if unexpected_keys:
        print(f"WARNING: Unexpected keys found: {unexpected_keys}")

    # Free transformer memory before serializing the controlnet.
    del base_transformer

    print(f"Saving blockwise ControlNet in Diffusers format to {args.output_path}")
    controlnet.to(dtype).save_pretrained(args.output_path)
    print("Done!")


# Script entry point: run the conversion with the CLI arguments parsed above.
if __name__ == "__main__":
    main(args)
51 changes: 44 additions & 7 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
# Copyright 2025 Black Forest Labs, The HuggingFace Team, The InstantX Team and The DiffSynth Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +44,29 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class BlockWiseControlBlock(nn.Module):
    """Fusion block for the DiffSynth blockwise ControlNet variant.

    Normalizes the base hidden states ``x`` and the control features ``y``
    independently with RMSNorm, sums them, and passes the result through a
    Linear -> GELU -> Linear projection. The output projection is wrapped in
    ``zero_module`` so the control signal starts as a no-op after init.
    Unlike the single linear projection used by the InstantX variant, both
    inputs are fused here.
    """

    def __init__(self, dim: int = 3072):
        super().__init__()
        # NOTE: attribute names must stay as-is — they are the state-dict keys
        # that the DiffSynth checkpoints are loaded against.
        self.x_rms = RMSNorm(dim, eps=1e-6)
        self.y_rms = RMSNorm(dim, eps=1e-6)
        self.input_proj = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.output_proj = zero_module(nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        fused = self.input_proj(self.x_rms(x) + self.y_rms(y))
        return self.output_proj(self.act(fused))


@dataclass
class QwenImageControlNetOutput(BaseOutput):
controlnet_block_samples: tuple[torch.Tensor]
Expand All @@ -65,6 +89,7 @@ def __init__(
joint_attention_dim: int = 3584,
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
extra_condition_channels: int = 0, # for controlnet-inpainting
controlnet_block_type: str = "linear",
):
super().__init__()
self.out_channels = out_channels or in_channels
Expand Down Expand Up @@ -92,8 +117,12 @@ def __init__(

# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
if controlnet_block_type == "blockwise":
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(BlockWiseControlBlock(self.inner_dim))
else:
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_x_embedder = zero_module(
torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
)
Expand All @@ -109,12 +138,14 @@ def from_transformer(
num_attention_heads: int = 24,
load_weights_from_transformer=True,
extra_condition_channels: int = 0,
controlnet_block_type: str = "linear",
):
config = dict(transformer.config)
config["num_layers"] = num_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
config["extra_condition_channels"] = extra_condition_channels
config["controlnet_block_type"] = controlnet_block_type

controlnet = cls.from_config(config)

Expand Down Expand Up @@ -190,7 +221,8 @@ def forward(
hidden_states = self.img_in(hidden_states)

# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
controlnet_cond_embed = self.controlnet_x_embedder(controlnet_cond)
hidden_states = hidden_states + controlnet_cond_embed

temb = self.time_text_embed(timestep, hidden_states)

Expand Down Expand Up @@ -240,9 +272,14 @@ def forward(

# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
if self.config.controlnet_block_type == "blockwise":
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample, controlnet_cond_embed)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
else:
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)

# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
Expand Down
51 changes: 51 additions & 0 deletions tests/pipelines/qwenimage/test_qwenimage_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,57 @@ def test_attention_slicing_forward_pass(
def test_inference_batch_single_identical(self):
    # Delegates to the shared pipeline-test harness with a loose 1e-1
    # tolerance, comparing a batch of 3 against single-sample inference.
    self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

def test_qwen_blockwise_controlnet(self):
    """Pipeline runs end-to-end with a blockwise-variant ControlNet."""
    device = "cpu"
    components = self.get_dummy_components()

    # Swap in a blockwise ControlNet built with the same dummy dimensions.
    torch.manual_seed(0)
    components["controlnet"] = QwenImageControlNetModel(
        patch_size=2,
        in_channels=16,
        out_channels=4,
        num_layers=2,
        attention_head_dim=16,
        num_attention_heads=3,
        joint_attention_dim=16,
        axes_dims_rope=(8, 4, 4),
        controlnet_block_type="blockwise",
    )

    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generated_image = pipe(**self.get_dummy_inputs(device)).images[0]
    self.assertEqual(generated_image.shape, (3, 32, 32))

def test_qwen_blockwise_controlnet_from_transformer(self):
    """from_transformer() yields a working blockwise ControlNet."""
    device = "cpu"
    components = self.get_dummy_components()

    # Derive the blockwise ControlNet directly from the dummy transformer.
    components["controlnet"] = QwenImageControlNetModel.from_transformer(
        components["transformer"],
        num_layers=2,
        attention_head_dim=16,
        num_attention_heads=3,
        controlnet_block_type="blockwise",
    )

    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generated_image = pipe(**self.get_dummy_inputs(device)).images[0]
    self.assertEqual(generated_image.shape, (3, 32, 32))

def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
Expand Down