Skip to content

Commit 75c85e7

Browse files
authored
Qualcomm AI Engine Direct - Resolved RMSNorm issue without weight (pytorch#18219)
1 parent b73ca05 commit 75c85e7

4 files changed

Lines changed: 31 additions & 11 deletions

File tree

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,25 @@ def define_node(
6161
axes = [node.args[0].meta["val"].dim() - 1]
6262
axes_shape = [len(axes)]
6363

64-
weight_node = self.get_node(node.args[2])
65-
weight_tensor = get_parameter(weight_node, self.edge_program)
64+
has_weight = len(node.args) > 2 and node.args[2] is not None
65+
if has_weight:
66+
weight_node = self.get_node(node.args[2])
67+
weight_tensor = get_parameter(weight_node, self.edge_program)
68+
else:
69+
# elementwise_affine=False: use all-ones weight as identity
70+
weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
71+
weight_node = torch.fx.Node(
72+
node.graph,
73+
node.name + "_runtime_weight",
74+
"call_function",
75+
exir_ops.edge.aten.tensor.default,
76+
(),
77+
{},
78+
)
79+
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
80+
quant_attrs = quant_attrs.copy()
81+
quant_attrs[QCOM_ZERO_POINT] = 0
82+
weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
6683
weight_tensor_wrapper = self.define_tensor(
6784
weight_node,
6885
node,

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,6 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
12891289
return
12901290

12911291
act_node = node.args[0]
1292-
weight_node = node.args[2]
12931292

12941293
# TODO: currently only 16a16w is supported
12951294
annotate_input_qspec_map(
@@ -1298,11 +1297,13 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
12981297
quantization_config.input_activation,
12991298
)
13001299

1301-
annotate_input_qspec_map(
1302-
node,
1303-
weight_node,
1304-
quantization_config.input_activation,
1305-
)
1300+
if len(node.args) > 2 and node.args[2] is not None:
1301+
weight_node = node.args[2]
1302+
annotate_input_qspec_map(
1303+
node,
1304+
weight_node,
1305+
quantization_config.input_activation,
1306+
)
13061307
nodes_to_mark_annotated = [node]
13071308
annotate_output_qspec(node, quantization_config.output_activation)
13081309
_mark_nodes_as_annotated(nodes_to_mark_annotated)

backends/qualcomm/tests/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,11 +1897,11 @@ def forward(self, x):
18971897

18981898

18991899
class RmsNorm(torch.nn.Module):
1900-
def __init__(self, eps=None):
1900+
def __init__(self, eps=None, elementwise_affine=True):
19011901
super().__init__()
1902-
self.rms = torch.nn.RMSNorm([4])
1902+
self.rms = torch.nn.RMSNorm([4], elementwise_affine=elementwise_affine)
19031903
if eps:
1904-
self.rms = torch.nn.RMSNorm([4], eps)
1904+
self.rms = torch.nn.RMSNorm([4], eps, elementwise_affine=elementwise_affine)
19051905

19061906
def forward(self, x):
19071907
return self.rms(x)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,7 @@ def test_qnn_backend_rms_norm(self):
16301630
modules = [
16311631
RmsNorm(), # noqa: F405
16321632
RmsNorm(eps=1e-5), # noqa: F405
1633+
RmsNorm(elementwise_affine=False), # noqa: F405
16331634
]
16341635
sample_input = (torch.randn([1, 1, 1, 4]),)
16351636
for i, module in enumerate(modules):
@@ -3958,6 +3959,7 @@ def test_qnn_backend_rms_norm(self):
39583959
modules = [
39593960
RmsNorm(), # noqa: F405
39603961
RmsNorm(eps=1e-5), # noqa: F405
3962+
RmsNorm(elementwise_affine=False), # noqa: F405
39613963
]
39623964
sample_input = (torch.randn([1, 1, 1, 4]),)
39633965
for i, module in enumerate(modules):

0 commit comments

Comments
 (0)