Commits (40)
387ad2a
Add KaSA implementation to layer.py
iambogeumkim May 17, 2025
ae00a34
Add `use_kasa` argument to LoraConfig
iambogeumkim Aug 2, 2025
6588f1a
Add use_kasa parameter to Linear class
iambogeumkim Sep 2, 2025
a824ac9
Add KasaLinearVariant class (just copy of DoraLinearVariant class) in…
iambogeumkim Sep 2, 2025
05e4e07
Add kasa description
iambogeumkim Sep 2, 2025
d1e7e43
Remove unnecessary self.kasa
iambogeumkim Sep 4, 2025
9e53b9f
[WIP] update KasaLinearVariant class with SVD implementation
iambogeumkim Sep 8, 2025
aa37111
Modify merge/unmerge method in KasaLinearVariant class
iambogeumkim Sep 8, 2025
9cfe65c
update KasaLinearVariant class with SVD implementation
iambogeumkim Sep 8, 2025
f9d7cc7
fix type in init method
iambogeumkim Sep 8, 2025
84813a3
delete unnecessary part in layer.py
iambogeumkim Sep 16, 2025
39abcad
add original reference in layer.py
iambogeumkim Sep 16, 2025
06f76d8
merge main to peft-kasa
iambogeumkim Sep 16, 2025
0043ae3
re-add KaSA implementation to variants.py
iambogeumkim Sep 20, 2025
ea59432
add use_kasa param to resolve_lora_variant in other layers
iambogeumkim Sep 27, 2025
574c1b8
delete unnecessary part in layer.py
iambogeumkim Sep 29, 2025
73fa58f
delete unnecessary part in variants.py
iambogeumkim Sep 29, 2025
b9e3190
add _get_delta_weight static method in KasaLinearVariants class
iambogeumkim Oct 4, 2025
2649fdf
update module.get_delta_weight to KasaLinearVariant._get_delta_weight…
iambogeumkim Oct 7, 2025
d06cafd
add kasa test
iambogeumkim Oct 7, 2025
4ee5da1
add dropout
iambogeumkim Oct 8, 2025
0377170
Update tests/test_custom_models.py
iambogeumkim Oct 11, 2025
a536bbf
Update src/peft/tuners/lora/variants.py
iambogeumkim Oct 11, 2025
f8d8057
Update src/peft/tuners/lora/variants.py
iambogeumkim Oct 11, 2025
cd57c7b
Update src/peft/tuners/lora/variants.py
iambogeumkim Oct 11, 2025
2431a2c
add use_kasa param in LoraModel class
iambogeumkim Oct 11, 2025
7276b3b
restore output_tensor variable in Linear class get_delta_weight method
iambogeumkim Oct 11, 2025
5a67b1f
add use_kasa handling condition in resolve_lora_variant method
iambogeumkim Oct 11, 2025
3ec6b18
fix KaSA self, mat1, mat2 dtype error
iambogeumkim Oct 11, 2025
461a89c
fix make style error
iambogeumkim Oct 11, 2025
6ed64c1
add _skip_test_disable_adapters function
iambogeumkim Oct 18, 2025
92061dd
add KaSA compatibility test
iambogeumkim Oct 21, 2025
9a2cf71
update _check_new_adapter_config method with KaSA
iambogeumkim Oct 24, 2025
39cf1f9
Implement tests to ensure KaSA adapters cannot be mixed with other ad…
iambogeumkim Dec 6, 2025
283ff0a
Refactor KaSA adapter compatibility check to simplify logic and impro…
iambogeumkim Dec 6, 2025
bfe8996
Refactor KasaLinearVariant class to improve code readability and ensu…
iambogeumkim Dec 6, 2025
cbc5b0c
Merge branch 'main' into peft-kasa
iambogeumkim Dec 6, 2025
14fa9d7
Merge branch 'main' into peft-kasa
iambogeumkim Dec 8, 2025
b6aae1e
Remove tests for mixing KaSA adapters with other adapter types in Tes…
iambogeumkim Dec 8, 2025
951b6b2
Add tests to validate that KaSA adapters cannot be mixed with other a…
iambogeumkim Dec 8, 2025
10 changes: 10 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -675,6 +675,16 @@ class LoraConfig(PeftConfig):
},
)

use_kasa: bool = field(
default=False,
metadata={
"help": (
"Enable <a href='https://arxiv.org/abs/2412.06071'>'Knowledge-Aware Singular-Value Adaptation of Large Language Models' (KaSA)</a>. This technique leverages "
"singular value decomposition (SVD) with knowledge-aware singular values to dynamically "
"activate parametric knowledge according to its relevance to downstream tasks."
)
}
)
def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
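For orientation, a minimal usage sketch of the new flag, assuming the use_kasa field lands as shown in the diff above; the base model and target modules below are illustrative and not taken from this PR:

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Illustrative base model; any model whose target modules are nn.Linear layers would do
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # illustrative target modules
    use_kasa=True,  # enable the KaSA variant added in this PR
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
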
34 changes: 23 additions & 11 deletions src/peft/tuners/lora/layer.py
@@ -127,7 +127,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
Expand All @@ -150,6 +150,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
@@ -175,11 +176,13 @@

lora_variant = self.resolve_lora_variant(
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
use_qalora=use_qalora,
qalora_group_size=qalora_group_size,
arrow_config=arrow_config,
)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -610,6 +613,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
@@ -628,27 +632,30 @@
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, use_kasa: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant

return ArrowLinearVariant()

if not use_dora and not use_alora:
if not use_dora and not use_alora and not use_kasa:
return None

from .variants import ALoraLinearVariant, DoraLinearVariant
from .variants import ALoraLinearVariant, DoraLinearVariant, KasaLinearVariant

if use_alora:
return ALoraLinearVariant()
elif use_kasa:
return KasaLinearVariant()
else:
return DoraLinearVariant()

@@ -862,7 +869,7 @@ def __init__(
arrow_config=arrow_config,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -879,6 +886,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -891,7 +899,8 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa, arrow_config=arrow_config)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1147,6 +1156,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -1189,6 +1199,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -1208,7 +1219,7 @@
PeftWarning,
)

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config, use_kasa=use_kasa)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1452,7 +1463,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1469,7 +1480,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1486,7 +1497,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1969,6 +1980,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
qalora_group_size: int = 32,
@@ -1985,7 +1997,7 @@
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size, use_kasa=use_kasa
)
if lora_variant is not None:
raise ValueError(f"lora.{self.__class__.__name__} does not work with LoRA variants like DoRA.")
25 changes: 25 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -155,6 +155,29 @@ class LoraModel(BaseTuner):
prefix: str = "lora_"
tuner_layer_cls = LoraLayer
target_module_mapping = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.

Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.

"""
# TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
# does not fully correspond to the error message.
if (len(self.peft_config) > 1) and (config.bias != "none"):
raise ValueError(
f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
"set bias to 'none' for all adapters."
)
Review comment on lines +168 to +172 (Member): Let's remove this and call super()._check_new_adapter_config(config) instead.
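A hedged sketch of what that suggestion would look like, assuming the parent tuner class already implements the shared bias check in its own _check_new_adapter_config:

    def _check_new_adapter_config(self, config: LoraConfig) -> None:
        # Delegate the generic bias check to the parent class instead of duplicating it here
        super()._check_new_adapter_config(config)

        # Keep only the KaSA-specific rule in this subclass: adapters must be all KaSA or all non-KaSA
        if len(self.peft_config) > 1:
            kasa_count = sum(1 for cfg in self.peft_config.values() if cfg.use_kasa)
            if 0 < kasa_count < len(self.peft_config):
                raise ValueError("KaSA adapters cannot be mixed with other adapter types.")
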
# Check KaSA adapter compatibility (only when adding additional adapters)
if len(self.peft_config) > 1:
kasa_count = sum(1 for cfg in self.peft_config.values() if cfg.use_kasa)
non_kasa_count = len(self.peft_config) - kasa_count

if kasa_count > 0 and non_kasa_count > 0:
raise ValueError("KaSA adapters cannot be mixed with other adapter types.")

def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
r"""
@@ -211,6 +234,7 @@ def _create_and_replace(
"use_dora": lora_config.use_dora,
"use_alora": lora_config.alora_invocation_tokens is not None,
"use_qalora": lora_config.use_qalora,
"use_kasa": lora_config.use_kasa,
"qalora_group_size": lora_config.qalora_group_size,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
@@ -248,6 +272,7 @@ def _create_and_replace(
init_lora_weights=lora_config.init_lora_weights,
use_rslora=lora_config.use_rslora,
use_dora=lora_config.use_dora,
use_kasa=lora_config.use_kasa,
lora_bias=lora_config.lora_bias,
arrow_config=lora_config.arrow_config,
inference_mode=lora_config.inference_mode,
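To illustrate the compatibility check above, a hedged sketch of the failure mode it guards against; the toy model and the adapter name "kasa" are made up for the example:

import torch.nn as nn

from peft import LoraConfig, get_peft_model

# A toy model with a single Linear layer, whose module name is "0"
base = nn.Sequential(nn.Linear(16, 16))
peft_model = get_peft_model(base, LoraConfig(target_modules=["0"]))

# Adding a KaSA adapter on top of the existing plain-LoRA adapter should hit the new check
try:
    peft_model.add_adapter("kasa", LoraConfig(target_modules=["0"], use_kasa=True))
except ValueError as err:
    print(err)  # expected: KaSA adapters cannot be mixed with other adapter types.
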
110 changes: 110 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -447,6 +447,116 @@ def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class KasaLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
if not hasattr(module, "lora_diag"):
module.lora_diag = nn.ParameterDict()
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",)

# Initialize lora_diag with the same dtype as the base layer
base_dtype = module.get_base_layer().weight.dtype
module.lora_diag[adapter_name] = nn.Parameter(
torch.randn(module.r[adapter_name], dtype=base_dtype), requires_grad=True
)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
if not getattr(module, "_kasa_svd_applied", False):
weight = module.get_base_layer().weight
dtype = weight.dtype
svd_rank = module.in_features - module.r[adapter_name]
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
reconstructed_weight = U_principle @ torch.diag(S_principle) @ Vh_principle
module.get_base_layer().weight.data = reconstructed_weight.to(dtype)
module._kasa_svd_applied = True

@staticmethod
def _get_delta_weight(weight_A, weight_B, lora_diag, scaling, fan_in_fan_out):
# Ensure all tensors have the same dtype
target_dtype = weight_A.dtype
weight_B = weight_B.to(target_dtype)
lora_diag = lora_diag.to(target_dtype)

diag = torch.diag(lora_diag)
delta = weight_B @ diag @ weight_A
if fan_in_fan_out:
delta = delta.transpose(0, 1)
delta = delta * scaling
return delta

@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
return orig_weight - delta_weight

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor, **kwargs) -> torch.Tensor:
# Check if adapters are disabled
if module.disable_adapters:
return result

lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
diag = torch.diag(module.lora_diag[active_adapter])

# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110

# Ensure all tensors have the same dtype as the result
target_dtype = result.dtype
x = x.to(target_dtype)
diag = diag.to(target_dtype)

# Convert LoRA weights to target dtype
lora_A.weight.data = lora_A.weight.data.to(target_dtype)
lora_B.weight.data = lora_B.weight.data.to(target_dtype)

lora_A_output = lora_A(dropout(x))

if x.ndim == 3:
einsum_output = torch.einsum("ijk,kl->ijl", lora_A_output, diag)
lora_output = lora_B(einsum_output) * scaling
elif x.ndim == 2:
matmul_output = lora_A_output @ diag
lora_output = lora_B(matmul_output) * scaling
else:
raise ValueError(f"Using KaSA with inputs of shape {x.ndim} is not supported, only 2 or 3 dims.")

return result + lora_output


class QALoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
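For intuition about the arithmetic above: KasaLinearVariant truncates the base weight's SVD to its top in_features - r singular components, then adds a low-rank update gated by the learnable lora_diag vector. A standalone numeric sketch with made-up square shapes; it mirrors init, _get_delta_weight, and merge_safe but is not the PR's API:

import torch

d, r, scaling = 64, 4, 2.0
W = torch.randn(d, d)  # stand-in for the base Linear weight

# SVD truncation of the base weight to rank d - r, as in KasaLinearVariant.init
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
k = d - r
W_base = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]

# Low-rank KaSA update B @ diag(lora_diag) @ A, scaled, as in _get_delta_weight
lora_A = torch.randn(r, d)
lora_B = torch.randn(d, r)
lora_diag = torch.randn(r)
delta = (lora_B @ torch.diag(lora_diag) @ lora_A) * scaling

# merge_safe returns the sum; forward computes the same contribution on activations
W_effective = W_base + delta
print(W_effective.shape)  # torch.Size([64, 64])
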
6 changes: 6 additions & 0 deletions tests/test_custom_models.py
@@ -164,6 +164,8 @@
LoraConfig,
{"target_modules": ["lin0"], "target_parameters": ["lin1.weight"]},
),
("Vanilla MLP 7 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_kasa": True}),
("Vanilla MLP 8 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_kasa": True}),
#######
# IA³ #
#######
@@ -1231,6 +1233,9 @@ def _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config
if (config_cls == LoraConfig) and config_kwargs.get("target_parameters"):
pytest.skip("LoRA with multiple adapters with target_parameters is not supported")

def _skip_test_disable_adapters(config_cls, config_kwargs):
if (config_cls == LoraConfig) and config_kwargs.get("use_kasa"):
pytest.skip("KaSA modifies base weights, so adapter disable test is skipped")

class MLP(nn.Module):
def __init__(self, bias=True):
@@ -2138,6 +2143,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
# Test that it's possible to disable the adapter, in which case the model output should be identical to that of
# the base model.
_skip_test_disable_adapters(config_cls, config_kwargs)
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()
outputs_base = model(**X)