Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,18 +2455,22 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}

has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") or k.startswith("lora_unet__") for k in state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good lord. Double _ 🥲

if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.removeprefix("lora_unet__").removeprefix("lora_unet_"): v for k, v in state_dict.items()}

def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight

if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""
suffix = ""
for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
if key.endswith(sfx):
base = key[: -len(sfx)]
suffix = sfx
break
else:
base = key
Comment on lines -2466 to +2473
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope this does not break compatibility with existing LoRAs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried with multiple loras from AI-Toolkit and they all worked, did we support kohya ones? I didn't test with those because I couldn't find any on the hub


# Protected n-grams that must keep their internal underscores
protected = {
Expand All @@ -2477,6 +2481,9 @@ def convert_key(key: str) -> str:
("to", "out"),
# feed_forward
("feed", "forward"),
# noise and context refiner
("noise", "refiner"),
("context", "refiner"),
}

prot_by_len = {}
Expand All @@ -2501,7 +2508,7 @@ def convert_key(key: str) -> str:
i += 1

converted_base = ".".join(merged)
return converted_base + (("." + suffix) if suffix else "")
return converted_base + suffix

state_dict = {convert_key(k): v for k, v in state_dict.items()}

Expand Down
Loading