[CI] Refactor Wan Model Tests #13082
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
        encoder_hidden_states = hidden_states

    if attn.fused_projections:
-        if attn.cross_attention_dim_head is None:
+        if not attn.is_cross_attention:
            # In self-attention layers, we can fuse the entire QKV projection into a single linear
            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
        else:
@@ -502,24 +502,27 @@ def __init__(
        dim_head: int = 64,
        eps: float = 1e-6,
        cross_attention_dim_head: Optional[int] = None,
+        bias: bool = True,
        processor=None,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
-        self.cross_attention_head_dim = cross_attention_dim_head
+        self.cross_attention_dim_head = cross_attention_dim_head
Member: Same as above. It would be nice if you could explain these changes. Were they flagged by the newly written test suite?

Collaborator (Author): This is just to keep the naming convention consistent.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim_head is not None

        # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
        # NOTE: this is not used in "vanilla" WanAttention
        self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
        self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)

        # 2. QKV and Output Projections
-        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
-        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
-        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
-        self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
+        self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)

        # 3. QK Norm
        # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
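Side note on the new `bias` argument: the four projection layers previously hard-coded `bias=True`, and the diff threads a single flag through instead. A minimal sketch of what that flag toggles (the helper function and sizes below are hypothetical, not the diffusers module):

```python
import torch.nn as nn

def build_projections(dim: int, inner_dim: int, kv_inner_dim: int, bias: bool = True):
    # One flag controls the bias of all four projections, mirroring the diff above.
    return nn.ModuleDict({
        "to_q": nn.Linear(dim, inner_dim, bias=bias),
        "to_k": nn.Linear(dim, kv_inner_dim, bias=bias),
        "to_v": nn.Linear(dim, kv_inner_dim, bias=bias),
        "to_out": nn.Linear(inner_dim, dim, bias=bias),
    })

projections = build_projections(dim=32, inner_dim=32, kv_inner_dim=32, bias=False)
assert projections["to_q"].bias is None  # bias=False omits the bias parameters entirely
```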
@@ -682,15 +685,18 @@ def __init__(
        self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
        self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

-        self.is_cross_attention = cross_attention_dim_head is not None
+        if is_cross_attention is not None:
+            self.is_cross_attention = is_cross_attention
+        else:
+            self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

    def fuse_projections(self):
        if getattr(self, "fused_projections", False):
            return

-        if self.cross_attention_dim_head is None:
+        if not self.is_cross_attention:
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
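For readers less familiar with fused projections: fusion only makes sense in self-attention, where Q, K and V are all computed from the same `hidden_states`; in cross-attention, K and V come from the encoder states, so there is no single input to project, which is what the guard changes above encode. Below is a standalone sketch of the idea (illustrative sizes, not the actual `fuse_projections` code): concatenating the three weight matrices yields one linear whose output splits back into Q, K and V with `chunk(3)`.

```python
import torch
import torch.nn as nn

dim, inner_dim = 32, 32  # illustrative sizes only
to_q, to_k, to_v = (nn.Linear(dim, inner_dim, bias=True) for _ in range(3))

# Concatenate the three projections along the output dimension.
fused_weight = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
fused_bias = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])
out_features, in_features = fused_weight.shape  # (3 * inner_dim, dim)

to_qkv = nn.Linear(in_features, out_features, bias=True)
to_qkv.weight.data.copy_(fused_weight)
to_qkv.bias.data.copy_(fused_bias)

# At runtime a single matmul replaces three, and chunk(3) recovers Q, K, V.
x = torch.randn(2, 16, dim)
query, key, value = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(query, to_q(x), atol=1e-6)
assert torch.allclose(key, to_k(x), atol=1e-6)
assert torch.allclose(value, to_v(x), atol=1e-6)
```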
@@ -176,15 +176,7 @@ def _test_quantization_inference(self, config_kwargs):
        model_quantized = self._create_quantized_model(config_kwargs)
        model_quantized.to(torch_device)

-        # Get model dtype from first parameter
-        model_dtype = next(model_quantized.parameters()).dtype
-
        inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
Comment on lines -179 to -187

Member: Why remove them?

Collaborator (Author): Casting here is brittle because it's based on next(model_quantized.parameters()).dtype, which isn't a reliable indicator of the compute dtype for a quantized model.

Member: Makes sense. But does it affect the existing Flux tests? I wonder if using a torch_dtype property would be cleaner.

Collaborator (Author): I've added a torch_dtype property to the quantization tests, and we cast the inputs directly in get_dummy_inputs. I think it's clearer this way. The Flux TorchAO and BnB tests will fail with this change, but I'll update the Flux2 PR to include fixes that address the change in this test.

Member: Sounds good. Thanks!
        output = model_quantized(**inputs, return_dict=False)[0]

        assert output is not None, "Model output is None"
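To make the approach discussed above concrete, here is a hedged sketch of how a test mixin might expose a `torch_dtype` property and cast inputs inside `get_dummy_inputs` (the class name, input names, and shapes are hypothetical; only the casting pattern follows the discussion and the removed code):

```python
import torch

class QuantizationTestMixinSketch:
    """Illustrative only; not the actual diffusers quantization test mixin."""

    @property
    def torch_dtype(self) -> torch.dtype:
        # The intended compute dtype is declared once, instead of being inferred from
        # next(model.parameters()).dtype, which is unreliable for quantized weights.
        return torch.bfloat16

    def get_dummy_inputs(self) -> dict:
        inputs = {
            "hidden_states": torch.randn(1, 4, 8, 8),  # hypothetical input
            "timestep": torch.tensor([1]),              # integer input stays untouched
        }
        # Cast only floating-point tensors to the declared dtype.
        return {
            k: v.to(self.torch_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
            for k, v in inputs.items()
        }

dummy = QuantizationTestMixinSketch().get_dummy_inputs()
assert dummy["hidden_states"].dtype == torch.bfloat16
assert dummy["timestep"].dtype == torch.int64
```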
@@ -229,6 +221,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
            init_lora_weights=False,
        )
        model.add_adapter(lora_config)
+        # Move LoRA adapter weights to device (they default to CPU)
+        model.to(torch_device)

        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]
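Context for the added `model.to(torch_device)` call: parameters created after the base model has been moved, as adapter injection does, start out on CPU, so the model is moved again to keep everything on one device. A minimal sketch of the failure mode and the fix, using plain `torch.nn` modules rather than the actual PEFT adapter machinery:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model is already on the target device.
model = nn.Sequential(nn.Linear(8, 8)).to(device)

# A freshly constructed submodule (analogous to injected LoRA layers) lives on CPU.
model.adapter = nn.Linear(8, 8)

# Moving the whole model afterwards puts base and adapter weights on the same
# device, which is what the added model.to(torch_device) ensures in the test.
model.to(device)
assert all(p.device.type == torch.device(device).type for p in model.parameters())
```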
@@ -1021,9 +1015,6 @@ def test_gguf_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize({"compute_dtype": torch.bfloat16})

-    def test_gguf_quantized_layers(self):
-        self._test_quantized_layers({"compute_dtype": torch.bfloat16})
-

@is_quantization
@is_modelopt

sayakpaul marked this conversation as resolved.