Skip to content

Commit 6d04ed3

Browse files
authored
Add LoRA support to StaticAttention for split_mha=False (pytorch#18345)
When ModelArgs.target_modules is set, create LoRALinear instead of nn.Linear for targeted q/k/v/o projections. Only applies to split_mha=False path. Existing behavior unchanged when target_modules is None. Authored with Claude.
1 parent 9846a56 commit 6d04ed3

1 file changed

Lines changed: 35 additions & 27 deletions

File tree

examples/models/llama/static_attention.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -812,38 +812,46 @@ def __init__(
812812
[StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
813813
)
814814
else:
815-
self.wqs = nn.ModuleList(
816-
[
817-
nn.Linear(
818-
self.dim,
819-
self.head_dim * self.n_heads,
820-
bias=self.attention_qkv_bias,
821-
)
822-
]
823-
)
824-
self.wks = nn.ModuleList(
825-
[
826-
nn.Linear(
827-
self.dim,
828-
self.head_dim * self.n_kv_heads,
829-
bias=self.attention_qkv_bias,
830-
)
831-
]
832-
)
833-
self.wvs = nn.ModuleList(
834-
[
835-
nn.Linear(
836-
self.dim,
837-
self.head_dim * self.n_kv_heads,
838-
bias=self.attention_qkv_bias,
815+
has_lora = config.target_modules is not None
816+
_PROJ_TARGET = {
817+
"wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
818+
"wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
819+
"wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
820+
}
821+
for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
822+
if has_lora and target in config.target_modules:
823+
proj = LoRALinear(
824+
in_dim=in_dim,
825+
out_dim=out_dim,
826+
rank=config.r,
827+
alpha=config.lora_alpha,
828+
use_bias=self.attention_qkv_bias,
839829
)
840-
]
841-
)
830+
else:
831+
proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias)
832+
setattr(self, attr, nn.ModuleList([proj]))
842833

843834
self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)])
844835
self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])
845836

846-
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
837+
wo_use_lora = (
838+
not self.split_mha
839+
and config.target_modules is not None
840+
and (
841+
"output_proj" in config.target_modules
842+
or "o_proj" in config.target_modules
843+
)
844+
)
845+
if wo_use_lora:
846+
self.wo = LoRALinear(
847+
in_dim=self.n_heads * self.head_dim,
848+
out_dim=self.dim,
849+
rank=config.r,
850+
alpha=config.lora_alpha,
851+
use_bias=False,
852+
)
853+
else:
854+
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
847855
self.rope = _Rope(rope.params)
848856
self.layer_id = layer_id
849857

0 commit comments

Comments
 (0)