9 changes: 6 additions & 3 deletions src/diffusers/models/transformers/transformer_chronoedit.py
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
         encoder_hidden_states = hidden_states
 
     if attn.fused_projections:
-        if attn.cross_attention_dim_head is None:
+        if not attn.is_cross_attention:
             # In self-attention layers, we can fuse the entire QKV projection into a single linear
             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
         else:
@@ -219,15 +219,18 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
-        self.is_cross_attention = cross_attention_dim_head is not None
+        if is_cross_attention is not None:
+            self.is_cross_attention = is_cross_attention
+        else:
+            self.is_cross_attention = cross_attention_dim_head is not None
 
         self.set_processor(processor)
 
     def fuse_projections(self):
         if getattr(self, "fused_projections", False):
             return
 
-        if self.cross_attention_dim_head is None:
+        if not self.is_cross_attention:
             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
             out_features, in_features = concatenated_weights.shape
9 changes: 6 additions & 3 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
         encoder_hidden_states = hidden_states
 
     if attn.fused_projections:
-        if attn.cross_attention_dim_head is None:
+        if not attn.is_cross_attention:
             # In self-attention layers, we can fuse the entire QKV projection into a single linear
             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
         else:
@@ -214,15 +214,18 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
-        self.is_cross_attention = cross_attention_dim_head is not None
+        if is_cross_attention is not None:
+            self.is_cross_attention = is_cross_attention
+        else:
+            self.is_cross_attention = cross_attention_dim_head is not None
 
         self.set_processor(processor)
 
     def fuse_projections(self):
         if getattr(self, "fused_projections", False):
             return
 
-        if self.cross_attention_dim_head is None:
+        if not self.is_cross_attention:
             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
             out_features, in_features = concatenated_weights.shape
22 changes: 14 additions & 8 deletions src/diffusers/models/transformers/transformer_wan_animate.py
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
         encoder_hidden_states = hidden_states
 
     if attn.fused_projections:
-        if attn.cross_attention_dim_head is None:
+        if not attn.is_cross_attention:
             # In self-attention layers, we can fuse the entire QKV projection into a single linear
             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
         else:
@@ -502,24 +502,27 @@ def __init__(
         dim_head: int = 64,
         eps: float = 1e-6,
         cross_attention_dim_head: Optional[int] = None,
+        bias: bool = True,
         processor=None,
     ):
         super().__init__()
         self.inner_dim = dim_head * heads
         self.heads = heads
-        self.cross_attention_head_dim = cross_attention_dim_head
+        self.cross_attention_dim_head = cross_attention_dim_head
Review thread (on the renamed attribute above):

Member: Same as above. Would be nice if you could explain these changes? Were these flagged by the newly written test suite?

Collaborator (author): This is just to keep the naming convention consistent.

         self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim_head is not None
 
         # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
         # NOTE: this is not used in "vanilla" WanAttention
         self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
         self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
 
         # 2. QKV and Output Projections
-        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
-        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
-        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
-        self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
+        self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
 
         # 3. QK Norm
         # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -682,15 +685,18 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
-        self.is_cross_attention = cross_attention_dim_head is not None
+        if is_cross_attention is not None:
+            self.is_cross_attention = is_cross_attention
+        else:
+            self.is_cross_attention = cross_attention_dim_head is not None
 
         self.set_processor(processor)
 
     def fuse_projections(self):
         if getattr(self, "fused_projections", False):
             return
 
-        if self.cross_attention_dim_head is None:
+        if not self.is_cross_attention:
             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
             out_features, in_features = concatenated_weights.shape
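The change repeated across these Wan attention classes is that is_cross_attention becomes an explicit constructor argument that takes precedence over the old inference from cross_attention_dim_head. Below is a minimal sketch of that resolution logic; the class and method names (AttentionFlags, can_fuse_qkv) are illustrative stand-ins, not the actual WanAttention implementation.

# Illustrative sketch of the flag-resolution pattern added in this PR; this
# class is a stand-in, not the actual WanAttention implementation.
from typing import Optional


class AttentionFlags:
    def __init__(
        self,
        cross_attention_dim_head: Optional[int] = None,
        is_cross_attention: Optional[bool] = None,
    ):
        self.cross_attention_dim_head = cross_attention_dim_head
        # Prefer the explicit flag; otherwise fall back to the old inference,
        # so callers that never pass is_cross_attention keep their behavior.
        if is_cross_attention is not None:
            self.is_cross_attention = is_cross_attention
        else:
            self.is_cross_attention = cross_attention_dim_head is not None

    def can_fuse_qkv(self) -> bool:
        # Fusing Q, K and V into one projection only makes sense for
        # self-attention, where all three project from the same hidden_states.
        return not self.is_cross_attention


print(AttentionFlags().can_fuse_qkv())                         # True
print(AttentionFlags(is_cross_attention=True).can_fuse_qkv())  # False

With the explicit flag, a layer that passes is_cross_attention=True (as the VACE transformer block below now does) is treated as cross-attention in fuse_projections() regardless of how cross_attention_dim_head is configured.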
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_wan_vace.py
@@ -76,6 +76,7 @@ def __init__(
             eps=eps,
             added_kv_proj_dim=added_kv_proj_dim,
             processor=WanAttnProcessor(),
+            is_cross_attention=True,
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

@@ -178,6 +179,7 @@ class WanVACETransformer3DModel(
     _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
 
     @register_to_config
     def __init__(
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -41,7 +41,7 @@ def __init__(self, quantization_config, **kwargs):
 
         self.compute_dtype = quantization_config.compute_dtype
         self.pre_quantized = quantization_config.pre_quantized
-        self.modules_to_not_convert = quantization_config.modules_to_not_convert
+        self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
 
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
15 changes: 6 additions & 9 deletions tests/models/testing_utils/common.py
@@ -446,16 +446,17 @@ def test_getattr_is_correct(self, caplog):
         torch_device not in ["cuda", "xpu"],
         reason="float16 and bfloat16 can only be used with an accelerator",
     )
-    def test_keep_in_fp32_modules(self):
+    def test_keep_in_fp32_modules(self, tmp_path):
         model = self.model_class(**self.get_init_dict())
         fp32_modules = model._keep_in_fp32_modules
 
         if fp32_modules is None or len(fp32_modules) == 0:
             pytest.skip("Model does not have _keep_in_fp32_modules defined.")
 
-        # Test with float16
-        model.to(torch_device)
-        model.to(torch.float16)
+        # Save the model and reload with float16 dtype
+        # _keep_in_fp32_modules is only enforced during from_pretrained loading
+        model.save_pretrained(tmp_path)
+        model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
 
         for name, param in model.named_parameters():
             if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
@@ -470,7 +471,7 @@
     )
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    @torch.no_grad()
-    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
+    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
         fp32_modules = model._keep_in_fp32_modules or []
@@ -490,10 +491,6 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
         output = model(**inputs, return_dict=False)[0]
         output_loaded = model_loaded(**inputs, return_dict=False)[0]
 
-        self._check_dtype_inference_output(output, output_loaded, dtype)
-
-    def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
-        """Check dtype inference output with configurable tolerance."""
         assert_tensors_close(
             output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
         )
13 changes: 2 additions & 11 deletions tests/models/testing_utils/quantization.py
@@ -176,15 +176,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)
 
-        # Get model dtype from first parameter
-        model_dtype = next(model_quantized.parameters()).dtype
-
         inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
Review thread (on the removed lines -179 to -187):

Member: Why remove them?

Collaborator (author): Casting here is brittle because it's based on model_dtype, which we get from model_dtype = next(model_quantized.parameters()).dtype. This can lead to different dtypes across different models and different quantization schemes. E.g., with Flux + GGUF the test passes because the parameter dtype is the same as the input dtype (bfloat16); with Wan it fails because the parameter dtype is int8.

Member: Makes sense. But does it affect the existing Flux tests? I wonder if using .dtype on a model subclassed from ModelMixin would alleviate this problem, since that dtype implementation is quite elaborate (see get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype).

Collaborator (author): I've added a torch_dtype property to the quantization tests, and we cast the inputs directly in get_dummy_inputs. I think it's clearer this way. Flux TorchAO and BnB tests will fail with this change, but I'll update the Flux2 PR to include fixes that address the change in this test.

Member: Sounds good. Thanks!

         output = model_quantized(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None"
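A minimal sketch of the approach described in the review thread above: a torch_dtype attribute declared on the quantization test class, with input casting moved into get_dummy_inputs. The class name, tensor names, and shapes below are illustrative assumptions rather than the actual diffusers test code; only the casting expression mirrors the removed lines.

# Hypothetical test-mixin sketch of the approach described in the review
# thread; not the actual diffusers test code.
import torch


class ExampleQuantizationTests:
    # Declared per test class instead of being inferred from the first
    # quantized parameter, whose storage dtype may be int8/uint8 (e.g. GGUF).
    torch_dtype = torch.bfloat16

    def get_dummy_inputs(self):
        raw = {
            "hidden_states": torch.randn(1, 4, 8, 8),  # illustrative shapes
            "timestep": torch.tensor([1]),
        }
        # Cast floating-point tensors to the declared dtype; leave integer
        # tensors (e.g. token ids, timesteps) untouched.
        return {
            k: v.to(self.torch_dtype)
            if isinstance(v, torch.Tensor) and v.is_floating_point()
            else v
            for k, v in raw.items()
        }


inputs = ExampleQuantizationTests().get_dummy_inputs()
print({k: v.dtype for k, v in inputs.items()})
# {'hidden_states': torch.bfloat16, 'timestep': torch.int64}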
@@ -229,6 +221,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
             init_lora_weights=False,
         )
         model.add_adapter(lora_config)
+        # Move LoRA adapter weights to device (they default to CPU)
+        model.to(torch_device)
 
         inputs = self.get_dummy_inputs()
         output = model(**inputs, return_dict=False)[0]
@@ -1021,9 +1015,6 @@ def test_gguf_dequantize(self):
         """Test that dequantize() works correctly."""
         self._test_dequantize({"compute_dtype": torch.bfloat16})
 
-    def test_gguf_quantized_layers(self):
-        self._test_quantized_layers({"compute_dtype": torch.bfloat16})
-
 
 @is_quantization
 @is_modelopt