Update lora def #18247

Open
lucylq wants to merge 1 commit into main from lfq.lora-def

Conversation


@lucylq lucylq commented Mar 17, 2026

Summary

Update the LoRA definition to use nn.Linear instead of torch.nn.functional.linear.

  1. Remove the debug string from the FQN. Otherwise, after this change, lora and non-lora models get different quant indices and the foundation weights are no longer shareable.
  2. Having nn.Linear as a submodule allows torchao quant to capture it, so we no longer need custom logic filtering LoRALinear. lora_a and lora_b are also captured, which changes our results slightly. Note that only lora_a is quantized (weight shape [16, dim], where dim % 32 == 0); lora_b has weight shape [dim, 16], and since 16 % 32 != 0 it fails the group-size check.
  3. Having nn.Linear as a submodule also means we need to remap lora weights, as weight names now have an extra 'linear' in them.
  4. Add @property for weight and bias so that the module remains BC and can be treated as a regular linear module. This is used in load_weights_from_attention_mha.
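The refactor described in points 2–4 can be sketched roughly as follows. This is a minimal illustration, not the PR's exact code; the rank default, dropout handling, and constructor signature are assumptions:

```python
import torch
from torch import nn


class LoRALinear(nn.Module):
    """Minimal sketch (not the PR's exact code): the base projection is a
    real nn.Linear submodule so torchao can find and quantize it, while
    weight/bias properties keep the old attribute access working (BC)."""

    def __init__(self, in_dim, out_dim, rank=16, use_bias=False, dropout=0.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)

    @property
    def weight(self):
        # BC accessor: code like load_weights_from_attention_mha can keep
        # reading module.weight as if this were a plain nn.Linear.
        return self.linear.weight

    @property
    def bias(self):
        return self.linear.bias

    def forward(self, x):
        # Previously: torch.nn.functional.linear(x, self.weight, self.bias)
        return self.linear(x) + self.lora_b(self.lora_a(self.dropout(x)))
```

Because the base weight is now registered under `linear.weight` rather than `weight`, old checkpoints need the key remapping described in point 3.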

Test plan

CI

@pytorch-bot

pytorch-bot bot commented Mar 17, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18247

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Awaiting Approval, 1 New Failure, 25 Pending

As of commit 1ebb7d7 with merge base eb92cec:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 17, 2026
@lucylq lucylq marked this pull request as ready for review March 17, 2026 20:10
Copilot AI review requested due to automatic review settings March 17, 2026 20:10
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment
Pull request overview

Updates the LLaMA example LoRA linear implementation to call an nn.Linear module in forward() (instead of torch.nn.functional.linear), presumably to align with module-based patterns.

Changes:

  • Replace torch.nn.functional.linear(x, self.weight, self.bias) with self.linear(x).
  • Introduce self.linear = nn.Linear(...) during LoRALinear initialization.
Comments suppressed due to low confidence (1)

examples/models/llama/lora.py:36

  • linear and weight are now undefined after switching to self.linear; bias = linear.bias ... and register_parameter("weight", ...) will raise at init time. Either derive bias/weight from self.linear, or remove the extra register_parameter calls entirely, so that construction works.
        self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
        bias = linear.bias if self.use_bias else None
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )


Copilot AI review requested due to automatic review settings March 17, 2026 20:59
Contributor

Copilot AI left a comment

Pull request overview

Updates the Llama LoRA implementation to call an nn.Linear submodule directly (instead of torch.nn.functional.linear) and simplifies quantization filtering accordingly.

Changes:

  • Refactors LoRALinear to wrap a real nn.Linear module and use it in forward.
  • Adds partial state-dict backward compatibility by remapping the legacy weight key.
  • Simplifies the 8da4w/8da8w quantization filter_fn to only consider nn.Linear modules.
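A filter of that shape might look like the following sketch. make_linear_filter is a hypothetical helper; the actual filter_fn in quantize.py may differ:

```python
from torch import nn


def make_linear_filter(group_size: int = 32):
    """Hypothetical filter_fn factory for torchao-style quantization: accept
    only nn.Linear modules whose in_features divides evenly into quantization
    groups. With rank 16, lora_a (in_features = dim, dim % 32 == 0) passes,
    while lora_b (in_features = 16) fails, matching the behavior noted above."""

    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear) and module.in_features % group_size == 0

    return filter_fn
```

Since LoRALinear now exposes its base projection as an nn.Linear submodule, no LoRA-specific special-casing is needed in the filter.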

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
examples/models/llama/source_transformation/quantize.py Simplifies the 8da4w/8da8w quantization filter to target nn.Linear modules only.
examples/models/llama/lora.py Refactors LoRALinear to delegate to an internal nn.Linear and updates forward/state-dict behavior.


@lucylq lucylq requested a review from digantdesai as a code owner March 17, 2026 23:22
Copilot AI review requested due to automatic review settings March 18, 2026 18:13
Contributor

Copilot AI left a comment

Pull request overview

Updates the LLaMA LoRA implementation to wrap a real nn.Linear submodule (instead of calling torch.nn.functional.linear directly), enabling TorchAO quantization tooling to recognize and quantize the base linear layer consistently across LoRA and non-LoRA models.

Changes:

  • Refactors LoRALinear to contain self.linear: nn.Linear, adds weight/bias properties for backward-compatible access, and remaps old checkpoint keys on load.
  • Simplifies TorchAO 8da4w/8da8w quantization filtering to target nn.Linear modules only (no special-casing LoRA modules).
  • Makes XNNPACK constant serialization keys content-hash-based (removes tensor-name prefix) and updates the LoRA CI expectation string accordingly.
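The legacy-key remapping mentioned above could be implemented with a load_state_dict pre-hook along these lines. This is a sketch: WrappedLinear and the hook name are illustrative, and _register_load_state_dict_pre_hook is a private (though long-standing) PyTorch API:

```python
import torch
from torch import nn


class WrappedLinear(nn.Module):
    """Sketch of the BC remapping: legacy checkpoints stored the base weight
    under "weight", but with an nn.Linear submodule it now lives at
    "linear.weight". A pre-hook renames old keys on the fly during load."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        # When registered this way, the hook receives the raw state_dict
        # (no module argument) plus the module's key prefix.
        self._register_load_state_dict_pre_hook(self._remap_legacy_keys)

    @staticmethod
    def _remap_legacy_keys(state_dict, prefix, *args):
        # Old checkpoints: "<prefix>weight"; new layout: "<prefix>linear.weight".
        old_key = prefix + "weight"
        if old_key in state_dict:
            state_dict[prefix + "linear.weight"] = state_dict.pop(old_key)

    def forward(self, x):
        return self.linear(x)
```

With the hook in place, both old and new checkpoints load under strict=True, since the legacy key is popped before validation.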

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
examples/models/llama/source_transformation/quantize.py Simplifies TorchAO quantization filter to match nn.Linear modules and apply group-size compatibility logic.
examples/models/llama/lora.py Refactors LoRALinear to wrap an nn.Linear submodule; adds BC weight/bias accessors and state-dict key remapping.
backends/xnnpack/operators/node_visitor.py Changes named constant key generation to be solely SHA256-based to stabilize dedup/indexing behavior.
.ci/scripts/test_lora.sh Updates expected output prefix text for the quantized LoRA test case.
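The content-hash-based keying described for node_visitor.py can be illustrated as follows. constant_key is a hypothetical helper, not the actual ExecuTorch code:

```python
import hashlib

import torch


def constant_key(tensor: torch.Tensor) -> str:
    """Hypothetical sketch of content-hash-based constant naming: keying a
    serialized constant by the SHA256 of its bytes (rather than by tensor
    name) lets identical foundation weights dedup across lora / non-lora
    graphs, since the key no longer depends on the tensor's FQN."""
    data = tensor.detach().cpu().contiguous().numpy().tobytes()
    return hashlib.sha256(data).hexdigest()
```

Two tensors with identical contents map to the same key regardless of what they are named in either graph, which is what makes the shared foundation weights dedup.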


Comment on lines 28 to 33

        self.use_bias = use_bias
        self.dropout = dropout

        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
        weight = linear.weight
        bias = linear.bias if self.use_bias else None
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )

        self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)