Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/scope/core/pipelines/wan2_1/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,23 @@ def parse_lora_weights(
f"parse_lora_weights: Matched base_key='{base_key}' -> model_key='{model_key}'"
)

# Validate LoRA dimensions against the model weight before injecting.
# lora_A shape: [rank, in_features] — in_features must match model weight dim 1
# lora_B shape: [out_features, rank] — out_features must match model weight dim 0
# (model weight shape is [out_features, in_features] for nn.Linear)
model_weight = model_state.get(model_key)
if model_weight is not None and lora_A.ndim == 2 and lora_B.ndim == 2:
lora_in = lora_A.shape[1] # LoRA expects this input dimension
lora_out = lora_B.shape[0] # LoRA expects this output dimension
model_out, model_in = model_weight.shape[0], model_weight.shape[1]
if lora_in != model_in or lora_out != model_out:
raise ValueError(
f"LoRA dimension mismatch at layer '{base_key}': "
f"LoRA expects ({lora_out}×{lora_in}) but model layer is ({model_out}×{model_in}). "
f"This LoRA was likely trained for a different model size (e.g. Wan2.1-5B vs 1.3B). "
f"Please use a LoRA that matches the loaded model architecture."
)

# Extract alpha and rank
alpha = None
if alpha_key and alpha_key in lora_state:
Expand Down
87 changes: 87 additions & 0 deletions tests/test_lora_dimension_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for LoRA dimension validation in parse_lora_weights.

Regression test for issue #922: a LoRA trained for Wan2.1-5B (in_features=5120)
was silently loaded into the Wan2.1-1.3B model (in_features=1536) and only
failed 156 times at inference time with an inscrutable RuntimeError.
"""

import pytest
import torch

from scope.core.pipelines.wan2_1.lora.utils import parse_lora_weights


def _make_model_state(in_features: int, out_features: int = 256) -> dict:
"""Minimal model state dict with one linear layer."""
return {
"blocks.0.self_attn.q.weight": torch.zeros(out_features, in_features),
}


def _make_lora_state(rank: int, in_features: int, out_features: int = 256) -> dict:
"""Minimal PEFT-format LoRA state targeting the same layer."""
return {
"diffusion_model.blocks.0.self_attn.q.lora_A.weight": torch.zeros(rank, in_features),
"diffusion_model.blocks.0.self_attn.q.lora_B.weight": torch.zeros(out_features, rank),
}


class TestLoRADimensionValidation:
"""Verify parse_lora_weights raises a clear error on dimension mismatch."""

def test_compatible_lora_loads_successfully(self):
"""LoRA matching the model's dimensions should parse without error."""
model_state = _make_model_state(in_features=1536)
lora_state = _make_lora_state(rank=32, in_features=1536)

mapping = parse_lora_weights(lora_state, model_state)

assert len(mapping) == 1
key = "blocks.0.self_attn.q.weight"
assert key in mapping
assert mapping[key]["rank"] == 32

def test_incompatible_lora_raises_value_error(self):
"""LoRA trained for 5B (in_features=5120) must not silently load into 1.3B (in_features=1536)."""
model_state = _make_model_state(in_features=1536) # 1.3B model
lora_state = _make_lora_state(rank=32, in_features=5120) # 5B LoRA

with pytest.raises(ValueError, match="LoRA dimension mismatch"):
parse_lora_weights(lora_state, model_state)

def test_error_message_is_user_friendly(self):
"""The error message should name the layer and the dimension sizes."""
model_state = _make_model_state(in_features=1536)
lora_state = _make_lora_state(rank=32, in_features=5120)

with pytest.raises(ValueError) as exc_info:
parse_lora_weights(lora_state, model_state)

msg = str(exc_info.value)
assert "blocks.0.self_attn.q" in msg, "Layer name should appear in error"
assert "5120" in msg, "LoRA in_features should appear in error"
assert "1536" in msg, "Model in_features should appear in error"
assert "model size" in msg.lower() or "architecture" in msg.lower(), (
"Error should hint at model size mismatch"
)

def test_out_features_mismatch_also_caught(self):
"""LoRA with wrong output dimension should also be rejected."""
model_state = _make_model_state(in_features=1536, out_features=256)
# LoRA with matching in_features but wrong out_features
lora_state = {
"diffusion_model.blocks.0.self_attn.q.lora_A.weight": torch.zeros(32, 1536),
"diffusion_model.blocks.0.self_attn.q.lora_B.weight": torch.zeros(512, 32), # wrong
}

with pytest.raises(ValueError, match="LoRA dimension mismatch"):
parse_lora_weights(lora_state, model_state)

def test_compatible_5b_lora_on_5b_model(self):
"""LoRA trained for 5B on a 5B model should load fine."""
model_state = _make_model_state(in_features=5120, out_features=5120)
lora_state = _make_lora_state(rank=32, in_features=5120, out_features=5120)

mapping = parse_lora_weights(lora_state, model_state)

assert len(mapping) == 1
Loading