
Commit 20a9829

FIX Account for rsLoRA scaling in set_scale (#2775)
1 parent 1806c16

File tree

2 files changed: +57, -3 lines


src/peft/tuners/lora/layer.py

Lines changed: 18 additions & 3 deletions
@@ -113,6 +113,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         self._disable_adapters = False
         self.merged_adapters = []
         self.use_dora: dict[str, bool] = {}  # not actively used anymore after #2443, keep it for BC
+        self.use_rslora: dict[str, bool] = {}
         self.lora_bias: dict[str, bool] = {}
         self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
         self._caches: dict[str, Any] = {}
@@ -255,6 +256,8 @@ def update_layer(
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
@@ -528,7 +531,10 @@ def set_scale(self, adapter: str, scale: float | int) -> None:
         if adapter not in self.scaling:
             # Ignore the case where the adapter is not in the layer
             return
-        self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
+        if self.use_rslora.get(adapter, False):
+            self.scaling[adapter] = scale * self.lora_alpha[adapter] / math.sqrt(self.r[adapter])
+        else:
+            self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]

     def scale_layer(self, scale: float | int) -> None:
         """Multiply the current scale of all active adapters by the provided factor"""
@@ -553,9 +559,12 @@ def unscale_layer(self, scale: Optional[float | int] = None) -> None:
                 continue

             if scale is None:
-                self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
+                if self.use_rslora.get(active_adapter, False):
+                    self.scaling[active_adapter] = self.lora_alpha[active_adapter] / math.sqrt(self.r[active_adapter])
+                else:
+                    self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
             else:
-                self.scaling[active_adapter] /= scale
+                self.scaling[active_adapter] = self.scaling[active_adapter] / scale

     def _check_forward_args(self, x, *args, **kwargs):
         """Check if the arguments are compatible with the configs and state of the model"""
@@ -960,6 +969,8 @@ def update_layer(
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         if init_lora_weights == "loftq":
@@ -1260,6 +1271,8 @@ def update_layer(
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         if init_lora_weights == "loftq":
@@ -2033,6 +2046,8 @@ def update_layer(
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
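
What the change fixes: update_layer already sets the baseline scaling to lora_alpha / sqrt(r) when rsLoRA is enabled, but set_scale and unscale_layer(scale=None) previously reset it to lora_alpha / r, silently dropping the rank-stabilized denominator. The commit records use_rslora per adapter and recomputes the baseline consistently. A minimal sketch of the rule, based only on the formulas in the diff above (the standalone helper names lora_scaling and set_scale here are illustrative, not part of PEFT):

import math

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool) -> float:
    # Baseline LoRA scaling: plain LoRA divides by the rank,
    # rank-stabilized LoRA (rsLoRA) divides by sqrt(rank).
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

def set_scale(scaling: dict, adapter: str, scale: float, lora_alpha: float, r: int, use_rslora: bool) -> None:
    # Mirrors the fixed behavior: the multiplier is applied on top of the
    # correct baseline instead of always using lora_alpha / r.
    scaling[adapter] = scale * lora_scaling(lora_alpha, r, use_rslora)

For example, with r=8 and lora_alpha=16, setting scale=3 on an rsLoRA adapter now yields 3 * 16 / sqrt(8) ≈ 16.97, where the old code produced 3 * 16 / 8 = 6.0.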

tests/test_initialization.py

Lines changed: 39 additions & 0 deletions
@@ -14,6 +14,7 @@

 import copy
 import itertools
+import math
 import platform
 import re
 import warnings
@@ -3891,6 +3892,44 @@ def test_scaling_simple(self, model):
         expected = [2.0] * n_layers
         assert scalings == expected

+    def test_scaling_with_rslora(self, model):
+        n_layers = 5
+        rank, lora_alpha = 8, 16
+        config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            use_rslora=True,
+            target_modules=["k_proj"],
+        )
+        model = get_peft_model(model, config)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # double
+        self.scale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [2 * lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # triple
+        self.set_scale(model, "default", 3)
+        scalings = self.get_scalings(model)
+        expected = [3 * lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
     def test_scaling_rank_pattern_alpha_pattern(self, model):
         # layer 0: 8 / 8
         # layer 1: 8 / 16
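
The scaling calls in the test (get_scalings, scale_layer, unscale_layer, set_scale) go through helpers of the test class; the same behavior can be checked directly on the LoRA layers. A hedged usage sketch, assuming the peft.tuners.lora.LoraLayer import path and an arbitrary tiny OPT checkpoint as the base model (both are illustrative choices, not part of this commit):

import math

from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM

# Illustrative base model: any model containing k_proj modules works.
base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(r=8, lora_alpha=16, use_rslora=True, target_modules=["k_proj"])
model = get_peft_model(base_model, config)

for module in model.modules():
    if isinstance(module, LoraLayer):
        # rsLoRA baseline set by update_layer: lora_alpha / sqrt(r)
        assert math.isclose(module.scaling["default"], 16 / math.sqrt(8))
        # With the fix, set_scale multiplies the rsLoRA baseline ...
        module.set_scale("default", 3)
        assert math.isclose(module.scaling["default"], 3 * 16 / math.sqrt(8))
        # ... and unscale_layer(None) restores it instead of falling back to lora_alpha / r.
        module.unscale_layer(None)
        assert math.isclose(module.scaling["default"], 16 / math.sqrt(8))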
