Conversation

@iambogeumkim

cc @BenjaminBossan

I was delayed in updating the code because I was focusing on company work, but now I'm planning to resume the project in earnest. If I have any questions about implementing the code, may I continue to ask you?

I apologize for opening a new pull request, as the previous one was closed 🥲 Thank you for your understanding.

@iambogeumkim iambogeumkim marked this pull request as draft August 2, 2025 05:45

Member

@BenjaminBossan BenjaminBossan left a comment

Thank you for resuming your work on KaSA.

Implementation-wise, we need to take a different approach. Right now, KaSA is just added to the normal LoRA code, but we only want to activate it if the user opts in. Therefore, it should be implemented in a separate class, something like KasaVariant, in peft/tuners/lora/variants.py. Please check how DoRA is implemented and use a similar approach, as I have detailed in my previous comment. If anything is unclear, feel free to ask.

@github-actions

github-actions bot commented Sep 1, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

gentle ping @nsbg

@iambogeumkim
Author

Thank you for the reminder!

I spent some time looking over the KaSA paper and code to get ready for more serious work, but it does seem pretty difficult 🥲 My goal is to upload code that's ready for review before the end of September, so I'm going to try even harder.

Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?

@BenjaminBossan
Member

That's great to see, thanks for picking this back up.

Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?

You're already on the right track, you added KasaLinearVariant, which is the most important step. There are definitely some changes required there, as there is some code that is only relevant for DoRA and can be removed for KaSA. But we can leave that as is for now.

Next, about resolving the variants. As a first step, let's revert the changes you made to lora/layer.py and start fresh. We don't need a self.use_kasa attribute; we only have self.use_dora for backwards compatibility, since we didn't have LoRA variants when we first implemented DoRA.

Then let's look at these lines in lora.Linear:

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
    if not use_dora:
        return None
    from .variants import DoraLinearVariant

    return DoraLinearVariant()

Here we need to extend the functionality to add KaSA. The updated method could be something like:

    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
        if use_dora and use_kasa:
            raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

        variant = None
        if use_dora:
            from .variants import DoraLinearVariant

            variant = DoraLinearVariant()
        elif use_kasa:
            ...

        return variant

Does that make sense? Similarly, we'd have to update the resolve_lora_variant methods of other LoRA layers, depending on whether they work with KaSA or not (I'm not sure if KaSA works with Conv2d etc.).

I would suggest that you work on this as a next step, then we'll see what else needs to be done.

@iambogeumkim
Author

Wow, I really appreciate your thoughtful feedback. I'll read your advice carefully and then move forward 🤗

@iambogeumkim
Author

@BenjaminBossan I modified the code in the files below based on what you explained. Please give me feedback if there are parts that still need fixing, and then we can discuss the next steps.

1. variants.py

  • Completed updates to the methods of the KasaLinearVariant class

2. layer.py

  • In the LoraLayer class, added self.use_kasa[adapter_name] = use_kasa inside the update_layer method

  • In the Linear class, added KaSA handling logic inside the get_delta_weight method

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for integrating my feedback. I gave this another review and noted the next few changes that are necessary. Please check my comments.

Apart from this, the branch is now encountering merge conflicts. Could you please bring your fork up-to-date with the remote and then merge with, or rebase on, the latest main branch from PEFT? If you have questions on how to resolve the merge conflicts, don't hesitate to ask.

Furthermore, please always run make style on your changes before pushing to make our linter happy.

More of a note for myself: Since KaSA updates the base weights of the model, we will have to take extra care to ensure that it works correctly when saving and loading the adapter.

"""
return None
if use_dora and use_kasa:

Member

Let's undo the changes in this method body and just return None. Since this KaSA layer is implemented for Linear only, add the logic to lora.Linear.resolve_lora_variant instead.

Also, we should update the resolve_lora_variant methods of the other layer types like lora.Embedding.resolve_lora_variant to accept the use_kasa argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.
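
For illustration, a rough sketch of what that could look like for lora.Embedding (not the final code; the error message is only a suggestion, and the DoRA branch assumes the existing DoraEmbeddingVariant):

def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool = False, **kwargs) -> Optional[LoraVariant]:
    if use_kasa:
        # KaSA is only implemented for Linear layers, so fail loudly instead of silently ignoring the flag
        raise ValueError("lora.Embedding does not support KaSA, please set use_kasa=False.")
    if not use_dora:
        return None
    from .variants import DoraEmbeddingVariant

    return DoraEmbeddingVariant()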

Comment on lines 236 to 247
############ kasa #############
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

weight = self.get_base_layer().weight
dtype = weight.dtype
svd_rank = self.in_features - r
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

#########################

Member

All of this can be removed, since it's part of KasaLinearVariant.init, right?

# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# SVD

Member

Let's add a reference here, so that we know the origin:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

Author

# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
        
# SVD

I put it in here; how does that look?

Comment on lines 335 to 348
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
    delta_weight = module.get_delta_weight(active_adapter)
    return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
    delta_weight = module.get_delta_weight(active_adapter)
    orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
    delta_weight = module.get_delta_weight(active_adapter)
    return orig_weight - delta_weight

Member

KaSA should have an influence on the merged weights, should it not?
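
For context, here is a minimal sketch of how a KaSA-aware delta weight could fold in the learned diagonal (the helper name is hypothetical; it simply mirrors the einsum used in the forward pass quoted further down in this review):

@staticmethod
def get_delta_weight_kasa(module: Linear, active_adapter: str) -> torch.Tensor:
    # delta_W = B @ diag(lora_diag) @ A * scaling, so merging and unmerging pick up the diagonal too
    lora_A = module.lora_A[active_adapter].weight          # (r, in_features)
    lora_B = module.lora_B[active_adapter].weight          # (out_features, r)
    diag = torch.diag(module.lora_diag[active_adapter])    # (r, r)
    return (lora_B @ diag @ lora_A) * module.scaling[active_adapter]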

Author

Although this PR is closed, it seems I've incorporated everything else except for this comment (of course, you'd have to look at the code). Could you explain this question in more detail?

x = dropout(x)

# KaSA calculation
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling

Member

Again, let's add a reference:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110

Author

# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
return result + lora_output

I inserted this near where the actual calculation logic begins, rather than just in an empty space. I think this is a bit better.

@iambogeumkim
Author

iambogeumkim commented Sep 16, 2025

@BenjaminBossan oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰

+) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?

@BenjaminBossan
Member

oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰

+) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?

I don't know what happened, but I could re-open the PR and there are some changes visible. Can you double check that everything looks as expected? If for some reason it's not what's expected, you can create a new PR and push your local branch.

@iambogeumkim
Author

I usually handle merges in the terminal, and I suspect the pull request was closed because I accidentally wiped the commit history while using the 'Sync fork' feature on GitHub. I'll be more careful in the future. Thanks for reopening it.

I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.

@BenjaminBossan
Member

I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.

No worries. If the diff on this PR looks good, let me know and I'll do a review. Only open a new PR if for some reason, the code here does not correspond to what it should be.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the latest updates. There were still a few minor issues, please check. Also, some further changes are needed:

1. the resolve_lora_variant from lora.Linear has yet to check if use_kasa and return the KasaLinearVariant() in that case
2. same here:

    use_dora=lora_config.use_dora,
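
To illustrate point 1, the completed method could look roughly like this (a sketch building on my earlier suggestion, not necessarily the final code):

    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool = False, **kwargs) -> Optional[LoraVariant]:
        if use_dora and use_kasa:
            raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

        variant = None
        if use_dora:
            from .variants import DoraLinearVariant

            variant = DoraLinearVariant()
        elif use_kasa:
            from .variants import KasaLinearVariant

            variant = KasaLinearVariant()
        return variant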

Note that you can run the tests locally to verify that they pass:

pytest tests/test_custom_models.py -k kasa -v

Once you're done with your changes, don't forget to call make style.

iambogeumkim and others added 7 commits October 11, 2025 15:37
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@iambogeumkim iambogeumkim marked this pull request as ready for review October 11, 2025 10:38
@iambogeumkim
Author

I have a question after local testing. I'll briefly share the test results and process.

1. 10 failures occurred

  • 8 dtype mismatch issues, 2 AssertionError

2. dtype issues (all resolved ✅)

  • When testing cases where the model's data type changes, like in the test_forward_bfloat16 method, dtype mismatch issues occurred.
  • This was resolved by adding dtype-matching logic to a method in the KasaLinearVariant class (see the sketch at the end of this comment).

3. AssertionError issues

  • An assert False error occurred in the test_disable_adapters method.

  • So I added the following conditional to the forward method, skipping the KaSA calculation and returning only the base model output when adapters are disabled, so that outputs_before == outputs_disabled holds in test_disable_adapters.

    # Check if adapters are disabled
    if module.disable_adapters:
        return result
  • However, assert False is still occurring.

But now that I think about it, the current KasaLinearVariant class applies SVD to the base weights in its init method, so I'm wondering whether the assertion failure is actually expected, or whether I just tried to solve it the wrong way.
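
For reference, the dtype fix from point 2 follows the usual casting pattern; here is a simplified, self-contained illustration (not my exact code):

import torch
import torch.nn as nn

x = torch.randn(2, 4, 16, dtype=torch.bfloat16)   # activations from a bf16 base model
lora_A = nn.Linear(16, 8, bias=False)             # adapter weights kept in float32
lora_B = nn.Linear(8, 16, bias=False)
lora_diag = nn.Parameter(torch.randn(8))
scaling = 0.5

x_cast = x.to(lora_A.weight.dtype)                # align dtypes before the matmuls
diag = torch.diag(lora_diag)
lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(x_cast), diag)) * scaling
lora_output = lora_output.to(x.dtype)             # cast back to the caller's dtype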

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for your recent updates. Your observation is correct: because KaSA modifies the base weights, it cannot simply be deactivated. To accommodate the tests, I would suggest writing a function similar to this one:

def _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs):
    if (config_cls == LoraConfig) and config_kwargs.get("target_parameters"):
        pytest.skip("LoRA with multiple adapters with target_parameters is not supported")

Then, for those tests that don't work with KaSA, let's invoke this skipping logic.
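
A KaSA-specific helper could look something like this (the name, flag, and skip message are just placeholders):

def _skip_tests_not_supported_with_kasa(config_cls, config_kwargs):
    # KaSA truncates the base weights via SVD at init time, so tests that disable the
    # adapter and expect to recover the original base output cannot pass.
    if (config_cls == LoraConfig) and config_kwargs.get("use_kasa"):
        pytest.skip("Test not supported when KaSA modifies the base weights")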

We also have to document this KaSA property. Furthermore, we should consider warning the user if they try to deactivate KaSA, but I'm not sure yet if there is a good place to do that.

Moreover, with the changed base weights, it also means that we cannot use a KaSA adapter together with other adapters (e.g. with a normal LoRA adapter), right? There should be a check for this, which can be implemented in LoraModel:

    def _check_new_adapter_config(self, config: VeraConfig) -> None:
        super()._check_new_adapter_config(config)
        # add a check here that we cannot have multiple adapters if one of them uses KaSA

We then need a new test in test_initialization.py to test this new check.

However, I think it would be possible to have multiple adapters if they all use KaSA, right? The modification to the base weight would be identical for all KaSA adapters, so they should be able to coexist. The issue is that, right now, each new KaSA adapter would re-apply SVD and modify the base weight, which is not good. We would need to have a mechanism to detect that this has already happened and avoid applying it twice. WDYT?

@iambogeumkim
Author

Sorry for the delayed update; I've been working on this in bits and pieces around my personal schedule. Since some time has passed, let me summarize the feedback you previously gave and the related work I've done:

1. _skip_test_disable_adapters function

  • Added skip logic for test functions that do not work with KaSA adapters. (you can check in tests/test_custom_models.py)

2. TestKasaInitialization class

  • test_kasa_mixed_adapter_error method: When existing adapters are not KaSA, adding a new KaSA adapter should raise a ValueError.
  • test_kasa_mixed_adapter_error_reverse method: When existing adapters are KaSA, adding a non-KaSA adapter should raise a ValueError.

Both of these test functions passed because they successfully raised a ValueError. My understanding is that "PASSED" here means the ValueError was raised, which is the intended behavior for these functions. Please correct me if my understanding is wrong.

3. Logic to avoid repetitive SVD processes (Just an idea 💡)

I agree with your opinion that repeatedly modifying the base weights for the same KaSA adapter is not ideal. My current idea is to apply a cache, as shown below, and check if SVD has already been applied by checking the existence of this attribute. What do you think of this approach?

if not hasattr(module, "_kasa_svd_cache"):
    # First KaSA adapter: perform SVD and cache the result
    weight = module.get_base_layer().weight
    dtype = weight.dtype
    weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
    module._kasa_svd_cache = (U, S, Vh)
else:
    # Reuse cached SVD result
    U, S, Vh = module._kasa_svd_cache
    dtype = module.get_base_layer().weight.dtype

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the updates, we're moving towards the goal. I commented on some of the changes, please check.

My current idea is to apply a cache, as shown below, and check if SVD has already been applied by checking the existence of this attribute. What do you think of this approach?

I'm not sure it's needed, but I could be wrong. AFAICT, the SVD is only needed to update the base weight. If we add a second KaSA adapter, the base weight is already updated, so there is no need to cache U, S, Vh, we can just skip completely. Is my understanding correct?
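
Something along these lines might be enough (an illustrative sketch inside KasaLinearVariant.init; _kasa_svd_applied is a hypothetical marker attribute):

base_layer = module.get_base_layer()
if not getattr(base_layer, "_kasa_svd_applied", False):
    weight = base_layer.weight
    dtype = weight.dtype
    svd_rank = module.in_features - module.r[adapter_name]
    U, S, Vh = torch.linalg.svd(weight.data.to(torch.float32), full_matrices=False)
    U_p, S_p, Vh_p = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
    base_layer.weight.data = (U_p @ torch.diag(S_p) @ Vh_p).to(dtype)
    # mark the base layer so that additional KaSA adapters skip the truncation entirely
    base_layer._kasa_svd_applied = True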

Also, there are some merge conflicts on this PR now. I think they should be easy to resolve, but don't hesitate to ask if you have questions. Finally, before committing, don't forget to call make style.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@iambogeumkim
Author

iambogeumkim commented Nov 29, 2025

Check

iambogeumkim and others added 4 commits December 6, 2025 14:34
…apter types, enhancing compatibility checks in the initialization process.
…re SVD is applied only once, while also cleaning up whitespace in multiple locations.
@iambogeumkim
Author

@BenjaminBossan

I've addressed the points you mentioned, applied make style, and resolved the conflicts. Let me know if anything else needs to be updated.

Regarding the SVD value caching, I gave it some thought and realized I was stuck on the idea that 'caching is always efficient.' Since the base weights are already updated by the first adapter even when multiple KaSA adapters are used, we can simply reuse them afterwards. So I modified the code to skip the calculation, as you suggested.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the new updates. We just merged another LoRA variant, which created merge conflicts with your PR, but they should be easy to resolve. Could you please take care of that? Thanks.

config1 = LoraConfig(
    r=8,
    target_modules=["linear"],
    init_lora_weights=True,

Member

You can remove this line, as it's irrelevant.

config2 = LoraConfig(
    r=16,
    target_modules=["linear"],
    init_lora_weights=True,

Member

You can remove this line, as it's irrelevant.

Author

# src/peft/tuners/lora/model.py
if len(self.peft_config) > 1:
    kasa_count = sum(1 for cfg in self.peft_config.values() if cfg.use_kasa)
    non_kasa_count = len(self.peft_config) - kasa_count

    if kasa_count > 0 and non_kasa_count > 0:
        raise ValueError("KaSA adapters cannot be mixed with other adapter types.")

I understood this to mean that since it's handled in this section, it's irrelevant elsewhere. Is my understanding correct?

Member

Oh, this was a misunderstanding. I meant that only the single line I commented on (init_lora_weights=True,) can be removed; the test as a whole is good to keep :) Please restore these tests.

Author

ah okay haha

Author

I changed the tests back :) !

iambogeumkim and others added 2 commits December 8, 2025 22:36
…tLoraInitialization, simplifying the test suite and focusing on essential compatibility checks.
@iambogeumkim
Author

I applied what you mentioned and resolved the conflicts. Please take a look!

…dapter types in TestLoraInitialization, ensuring compatibility checks are enforced in both configurations.