Skip to content

Commit 45fd1f3

Browse files
committed
Update lora def
1 parent eb92cec commit 45fd1f3

4 files changed

Lines changed: 26 additions & 37 deletions

File tree

.ci/scripts/test_lora.sh

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -138,9 +138,7 @@ EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start
138138
Okay, so I need to calculate 15% of 80."
139139
EXPECTED_QUANT_LORA_PREFIX="
140140
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
141-
To calculate 15% of 80, we can multiply 80 by 15/100.
142-
80 * 15/100 = 12.
143-
So, 15% of 80 is 12.
141+
15% of 80 is equal to (15/100) * 80 = 12. So, 15% of 80 is 12.
144142
#### 12
145143
The answer is: 12<|im_end|>"
146144

backends/xnnpack/operators/node_visitor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -625,7 +625,7 @@ def get_serialized_buffer_index(
625625
f"Serializing constant data node {tensor} but tensor value has no bytes",
626626
)
627627
sha256_hash = hashlib.sha256(bytes(array))
628-
named_key = tensor.name + "_" + sha256_hash.hexdigest()
628+
named_key = sha256_hash.hexdigest()
629629

630630
size = const_val.untyped_storage().nbytes()
631631
xnn_graph.constant_data.append(

examples/models/llama/lora.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,30 @@ def __init__(
2828
self.use_bias = use_bias
2929
self.dropout = dropout
3030

31-
linear = nn.Linear(in_dim, out_dim, bias=use_bias)
32-
weight = linear.weight
33-
bias = linear.bias if self.use_bias else None
34-
self.register_parameter("weight", nn.Parameter(weight))
35-
self.register_parameter(
36-
"bias", nn.Parameter(bias) if bias is not None else None
37-
)
38-
31+
self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
3932
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
4033
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
4134
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
4235

36+
@property
37+
def weight(self):
38+
return self.linear.weight
39+
40+
@property
41+
def bias(self):
42+
return self.linear.bias
43+
44+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
45+
# Remap old-style keys to "linear.*" for backward compat
46+
for attr in ("weight", "bias"):
47+
old_key = prefix + attr
48+
new_key = prefix + "linear." + attr
49+
if old_key in state_dict and new_key not in state_dict:
50+
state_dict[new_key] = state_dict.pop(old_key)
51+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
52+
4353
def forward(self, x: torch.Tensor) -> torch.Tensor:
44-
out = torch.nn.functional.linear(x, self.weight, self.bias)
54+
out = self.linear(x)
4555
lora_out = self.lora_a(self.dropout(x))
4656
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
4757

examples/models/llama/source_transformation/quantize.py

Lines changed: 5 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -144,30 +144,11 @@ def quantize( # noqa C901
144144
from torchao.utils import unwrap_tensor_subclass
145145

146146
def filter_fn(m, fqn):
147-
# Check if it's a regular nn.Linear
148-
is_linear = isinstance(m, nn.Linear)
149-
150-
# Check if it's a LoRALinear (which has a base weight parameter to quantize)
151-
is_lora_linear = False
152-
try:
153-
from executorch.examples.models.llama.lora import LoRALinear
154-
155-
is_lora_linear = isinstance(m, LoRALinear)
156-
except ImportError:
157-
pass
158-
159-
# Check if the weight shape is compatible with group size
160-
has_shape_compatible_with_group_size = False
161-
if is_linear or is_lora_linear:
162-
if group_size == 0:
163-
has_shape_compatible_with_group_size = True
164-
else:
165-
has_shape_compatible_with_group_size = (
166-
m.weight.shape[1] % group_size == 0
167-
)
168-
return (
169-
is_linear or is_lora_linear
170-
) and has_shape_compatible_with_group_size
147+
if not isinstance(m, nn.Linear):
148+
return False
149+
if group_size == 0:
150+
return True
151+
return m.weight.shape[1] % group_size == 0
171152

172153
weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8
173154
quantize_(

0 commit comments

Comments
 (0)