98 changes: 97 additions & 1 deletion tests/lora/utils.py
@@ -52,6 +52,23 @@
from peft.utils import get_peft_model_state_dict


def _transformers_strips_text_model_prefix() -> bool:
    """
    transformers>=5.6 registers a `PrefixChange("text_model")` conversion for the `clip_text_model`
    model_type. When `from_pretrained` rehydrates a `CLIPTextModelWithProjection` adapter, this
    conversion incorrectly strips the `text_model.` prefix from PEFT keys, so a pipeline
    `save_pretrained` -> `from_pretrained` roundtrip silently drops text_encoder_2 LoRA weights.
    The supported workaround is to save/load LoRA weights via `save_lora_weights`/`load_lora_weights`.
    """
    try:
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping
        from transformers.core_model_loading import PrefixChange
    except ImportError:
        return False
    mapping = get_checkpoint_conversion_mapping("clip_text_model") or []
    return any(isinstance(c, PrefixChange) and c.prefix_to_remove == "text_model" for c in mapping)

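For illustration, a minimal sketch of the failure mode this probe detects (the key name is hypothetical, and `str.removeprefix` stands in for what the registered conversion effectively does):

# Hypothetical PEFT key for a LoRA layer inside a CLIP text encoder.
peft_key = "text_model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight"
# What PrefixChange("text_model") effectively does during checkpoint conversion:
stripped_key = peft_key.removeprefix("text_model.")
# The adapter parameter still lives under `text_model.`, so the stripped key no longer
# matches any module parameter and the LoRA tensor is silently dropped on load.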

def state_dicts_almost_equal(sd1, sd2):
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))
@@ -299,6 +316,37 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

        return modules_to_save

    def _needs_text_encoder_lora_repair(self) -> bool:
        """
        transformers>=5.6 strips the `text_model.` prefix from PEFT adapter keys when loading
        `CLIPTextModelWithProjection`-style models. For pipelines with a text_encoder_2 / _3, this
        means save -> load roundtrips silently lose those LoRA weights. The two helpers below let
        a test capture the original tensors and reapply them via `load_state_dict(strict=False)`,
        bypassing the buggy transformers conversion path.
        """
        return (
            self.has_two_text_encoders or self.has_three_text_encoders
        ) and _transformers_strips_text_model_prefix()

    def _capture_text_encoder_lora_tensors(self, pipe):
        captured = {}
        for name in ("text_encoder", "text_encoder_2", "text_encoder_3"):
            module = getattr(pipe, name, None)
            if module is not None and getattr(module, "peft_config", None) is not None:
                captured[name] = {
                    k: v.detach().clone().cpu() for k, v in module.state_dict().items() if "lora" in k
                }
        return captured

    def _restore_text_encoder_lora_tensors(self, pipe, captured):
        for name, lora_tensors in captured.items():
            module = getattr(pipe, name)
            new_adapter_name = module.active_adapters()[0]
            target_device = next(module.parameters()).device
            # The tensors were captured under the original adapter name ("default"); rename the keys
            # to whatever adapter is active after reloading, then copy them back in place.
            repaired = {
                k.replace(".default.weight", f".{new_adapter_name}.weight"): v.to(target_device)
                for k, v in lora_tensors.items()
            }
            module.load_state_dict(repaired, strict=False)

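Taken together, the repair pair brackets a save/load roundtrip; a minimal usage sketch mirroring the test changes below (`pipe` and `tmpdirname` as in those tests):

        # Capture the LoRA tensors before the roundtrip, restore them after reloading.
        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
        pipe.unload_lora_weights()
        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
        if needs_lora_repair:
            self._restore_text_encoder_lora_tensors(pipe, captured_lora)
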
    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
        if text_lora_config is not None:
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -423,6 +471,9 @@ def test_low_cpu_mem_usage_with_loading(self):

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -434,6 +485,9 @@
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -447,6 +501,9 @@
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -578,6 +635,9 @@ def test_simple_inference_with_text_lora_save_load(self):

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -590,6 +650,9 @@
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -665,7 +728,15 @@ def test_simple_inference_with_partial_text_lora(self):

    def test_simple_inference_save_pretrained_with_text_lora(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA through save_pretrained.

        transformers>=5.6 registers a `clip_text_model` conversion that strips the `text_model.`
        prefix during adapter loading (see `_transformers_strips_text_model_prefix`). For pipelines
        whose text encoders use this conversion (e.g. SDXL's `CLIPTextModelWithProjection`),
        `pipe.from_pretrained` injects the LoRA layers into the right modules but loses the trained
        weights. Going through `load_lora_weights` afterwards hits the same conversion. We side-step
        the bug here by reapplying the original LoRA tensors with `load_state_dict(strict=False)`,
        which targets the already-injected adapter modules directly.
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
@@ -679,12 +750,18 @@
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)

            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
            pipe_from_pretrained.to(torch_device)

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe_from_pretrained, captured_lora)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
@@ -719,6 +796,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -730,6 +810,9 @@
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -2208,6 +2291,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
        )
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2216,6 +2302,9 @@
            pipe.unload_lora_weights()
            pipe.load_lora_weights(tmpdir)

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
@@ -2268,6 +2357,9 @@ def test_inference_load_delete_load_adapters(self):

        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        needs_lora_repair = self._needs_text_encoder_lora_repair()
        captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2282,6 +2374,10 @@

            # Then load adapter and compare.
            pipe.load_lora_weights(tmpdirname)

            if needs_lora_repair:
                self._restore_text_encoder_lora_tensors(pipe, captured_lora)

            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
