Skip to content

Commit 06f10b9

Browse files
psiddh and Github Executorch authored
[cortex_m] Fix linear weight layout: transpose in AOT pass, align meta/ref impl (pytorch#16782)
### Summary: The linear path in ConvertToCortexMPass was not transposing weights unlike conv2d, causing inconsistency with the C++ runtime which expects weights in [in_features, out_features] format per CMSIS-NN. Changes: - convert_to_cortex_m_pass.py: Transpose linear weights [out, in] -> [in, out] - operators.py: Update meta to use weights.shape[1] for output dimension - operators.py: Remove .T from ref impl (weights pre-transposed by pass) Fixes MV2 output shape mismatch: [1, 1280] -> [1, 1000] MV2 on Corstone-300/E8 with CMSIS-NN kernels This fix ensures the AOT-compiled .pte file has correctly shaped output tensors for any model using quantized_linear (MV2, ResNet, MV3, etc.). ### Test plan Run MV2 lowered to CMSIS-NN ops on E8 alif board Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent 3ddb86c commit 06f10b9

2 files changed

Lines changed: 24 additions & 7 deletions

File tree

backends/cortex_m/ops/operators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def quantized_linear_meta(
352352
activation_min,
353353
) -> torch.Tensor:
354354

355-
shape = (*input.shape[:-1], weights.shape[0])
355+
shape = (*input.shape[:-1], weights.shape[1])
356356
return torch.empty(shape, dtype=input.dtype, device=input.device)
357357

358358

@@ -386,7 +386,7 @@ def quantized_linear_impl(
386386
input_reshaped = input_int32.reshape(new_shape)
387387

388388
lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset
389-
output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum
389+
output = torch.mm(input_reshaped, weights_int32) + lhs_sum + kernel_sum
390390
output_shape = (*input.shape[:-1], output.shape[-1])
391391
output_reshaped = output.reshape(output_shape)
392392
else:
@@ -396,7 +396,7 @@ def quantized_linear_impl(
396396
new_shape = (prod(input.shape[:-1]), input.shape[-1])
397397
input_reshaped = input_int32.reshape(new_shape)
398398

399-
output = torch.mm(input_reshaped, weights_int32.T)
399+
output = torch.mm(input_reshaped, weights_int32)
400400
if bias is not None:
401401
output = output + bias
402402
output_shape = (*input.shape[:-1], output.shape[-1])

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,19 @@ class ConvertToCortexMPass(XNNPACKPass):
3333
by call_operator.
3434
"""
3535

36-
def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset):
36+
def _compute_kernel_sum(
37+
self, weights_transposed, bias, input_offset, weight_offset
38+
):
3739
"""
3840
Computes the precomputed kernel sum term (bias optional)
3941
a * sum_j(wij + b) + ci
4042
4143
for i = (1, ..., n), where j indexes the input activations.
44+
45+
Args:
46+
weights_transposed: Weights already in [in_features, out_features] format
4247
"""
43-
weights_transposed = weights.T
48+
# No transpose needed - weights already transposed by caller
4449
weights_int32 = weights_transposed.to(torch.int32)
4550
offset_weights = weights_int32 + weight_offset
4651
kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32)
@@ -110,8 +115,12 @@ def _get_linear_replacement(self, node):
110115
if len(node.args) > 2
111116
else None
112117
)
118+
# Transpose weights once from PyTorch format [out_features, in_features]
119+
# to CMSIS-NN format [in_features, out_features]
120+
weights_transposed = weights_tensor.T.contiguous()
121+
# Pass already-transposed weights to kernel_sum computation
113122
kernel_sum_tensor = self._compute_kernel_sum(
114-
weights_tensor, bias_tensor, -input_zp, -weight_zp
123+
weights_transposed, bias_tensor, -input_zp, -weight_zp
115124
)
116125
with node.graph.inserting_after(weights):
117126
kernel_sum = create_constant_placeholder(
@@ -122,9 +131,17 @@ def _get_linear_replacement(self, node):
122131
kernel_sum_tensor,
123132
)
124133

134+
weights_transposed_node = create_constant_placeholder(
135+
self.exported_program,
136+
node.graph,
137+
node.name + "_weights_transposed",
138+
InputKind.PARAMETER,
139+
weights_transposed,
140+
)
141+
125142
args = (
126143
node.args[0],
127-
weights,
144+
weights_transposed_node,
128145
None,
129146
kernel_sum,
130147
-input_zp,

0 commit comments

Comments (0)