Skip to content

Commit 0ed5508

Browse files
authored
Qualcomm AI Engine Direct - Fix Full (pytorch#19359)
### Summary Based on the reported issue pytorch#19179: two mathematically equivalent models (written as nn.Module) were both lowered to QNN; however, one produced good accuracy while the other produced poor accuracy. Debugging with qnn_intermediate_debugger indicated that the accuracy drop was likely caused by the `full` operation. On inspection, the op builder did not take the layout transform into consideration. <img width="1785" height="1171" alt="image" src="https://github.com/user-attachments/assets/c8ca4cf6-b242-4171-a8ee-3595e0bcd376" /> ### Test plan The tests under test_qnn_delegate.py pass, and this change fixes the issue reported in pytorch#19179.
1 parent 2d7ffad commit 0ed5508

4 files changed

Lines changed: 53 additions & 3 deletions

File tree

backends/qualcomm/builders/op_full.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ def define_node(
2525
node: torch.fx.Node,
2626
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
2727
) -> PyQnnManager.PyQnnOpWrapper:
28+
tensor_shape = list(self.get_tensor(node, node).shape)
2829
out_tensor = torch.full(
29-
node.args[0], node.args[1], dtype=node.meta["val"].dtype
30+
tensor_shape, node.args[1], dtype=node.meta["val"].dtype
3031
)
3132

3233
# since we can derive the constant value of current op in AoT stage

backends/qualcomm/builders/op_full_like.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def define_node(
2525
node: torch.fx.Node,
2626
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
2727
) -> PyQnnManager.PyQnnOpWrapper:
28-
in_tensor = node.args[0].meta["val"]
29-
ref_tensor = torch.zeros(in_tensor.shape, dtype=in_tensor.dtype)
28+
in_tensor = self.get_tensor(node, node)
29+
ref_tensor = torch.zeros(list(in_tensor.shape), dtype=in_tensor.dtype)
3030
out_tensor = torch.full_like(ref_tensor, node.args[1])
3131

3232
# since we can derive the constant value of current op in AoT stage

backends/qualcomm/tests/models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,31 @@ def forward(self, x):
875875
return self.second(self.first(x))
876876

877877

878+
class ConvFull(torch.nn.Module):
879+
def __init__(self, fill, full_shape):
880+
super().__init__()
881+
self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
882+
self.fill = fill
883+
self.full_shape = full_shape
884+
885+
def forward(self, x):
886+
y = self.conv(x)
887+
c = torch.full(self.full_shape, self.fill, dtype=y.dtype)
888+
return torch.cat([y, c], dim=1)
889+
890+
891+
class ConvFullLike(torch.nn.Module):
892+
def __init__(self, fill):
893+
super().__init__()
894+
self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
895+
self.fill = fill
896+
897+
def forward(self, x):
898+
y = self.conv(x)
899+
c = torch.full_like(y, self.fill)
900+
return torch.cat([y, c], dim=1)
901+
902+
878903
class ConvTranspose1dSingle(torch.nn.Module):
879904
def __init__(self, bias=True, dilation=1):
880905
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,6 +2351,17 @@ def test_qnn_backend_einsum_outer_product_relu(self):
23512351
)
23522352
self.lower_module_and_test_output(module, sample_input)
23532353

2354+
def test_qnn_backend_full_layout_transformed(self):
2355+
full_shape = (1, 16, 4, 6)
2356+
module = ConvFull(0.5, full_shape) # noqa: F405
2357+
sample_input = (torch.randn(1, 8, 4, 6),)
2358+
self.lower_module_and_test_output(module, sample_input)
2359+
2360+
def test_qnn_backend_full_like_layout_transformed(self):
2361+
module = ConvFullLike(0.5) # noqa: F405
2362+
sample_input = (torch.randn(1, 8, 4, 6),)
2363+
self.lower_module_and_test_output(module, sample_input)
2364+
23542365
# TODO: Create a new UT class for passes specific checks
23552366
def test_qnn_backend_lift_add_tensor(self):
23562367
module = LiftAddTensor() # noqa: F405
@@ -5270,6 +5281,19 @@ def test_qnn_backend_einsum_outer_product_relu(self):
52705281
module = self.get_qdq_module(module, sample_input)
52715282
self.lower_module_and_test_output(module, sample_input)
52725283

5284+
def test_qnn_backend_full_layout_transformed(self):
5285+
full_shape = (1, 16, 4, 6)
5286+
module = ConvFull(0.5, full_shape) # noqa: F405
5287+
sample_input = (torch.randn(1, 8, 4, 6),)
5288+
module = self.get_qdq_module(module, sample_input)
5289+
self.lower_module_and_test_output(module, sample_input)
5290+
5291+
def test_qnn_backend_full_like_layout_transformed(self):
5292+
module = ConvFullLike(0.5) # noqa: F405
5293+
sample_input = (torch.randn(1, 8, 4, 6),)
5294+
module = self.get_qdq_module(module, sample_input)
5295+
self.lower_module_and_test_output(module, sample_input)
5296+
52735297
@unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
52745298
def test_qnn_backend_masked_softmax(self):
52755299
if self.enable_x86_64:

0 commit comments

Comments
 (0)