Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/qualcomm/builders/op_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
tensor_shape = list(self.get_tensor(node, node).shape)
out_tensor = torch.full(
node.args[0], node.args[1], dtype=node.meta["val"].dtype
tensor_shape, node.args[1], dtype=node.meta["val"].dtype
)

# since we can derive the constant value of current op in AoT stage
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_full_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
in_tensor = node.args[0].meta["val"]
ref_tensor = torch.zeros(in_tensor.shape, dtype=in_tensor.dtype)
in_tensor = self.get_tensor(node, node)
ref_tensor = torch.zeros(list(in_tensor.shape), dtype=in_tensor.dtype)
out_tensor = torch.full_like(ref_tensor, node.args[1])

# since we can derive the constant value of current op in AoT stage
Expand Down
25 changes: 25 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,31 @@ def forward(self, x):
return self.second(self.first(x))


class ConvFull(torch.nn.Module):
    """Conv2d whose output is concatenated with a constant tensor.

    The constant is produced by ``torch.full`` with a caller-supplied shape
    and fill value, matching the conv output's dtype, and joined along the
    channel dimension.
    """

    def __init__(self, fill, full_shape):
        super().__init__()
        # 8 -> 16 channels, 3x3 kernel, padding keeps spatial size.
        self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
        self.fill = fill
        self.full_shape = full_shape

    def forward(self, x):
        conv_out = self.conv(x)
        const = torch.full(self.full_shape, self.fill, dtype=conv_out.dtype)
        return torch.cat([conv_out, const], dim=1)


class ConvFullLike(torch.nn.Module):
    """Conv2d whose output is concatenated with a ``torch.full_like`` constant.

    The constant tensor inherits shape and dtype from the conv output and is
    joined along the channel dimension.
    """

    def __init__(self, fill):
        super().__init__()
        # 8 -> 16 channels, 3x3 kernel, padding keeps spatial size.
        self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
        self.fill = fill

    def forward(self, x):
        conv_out = self.conv(x)
        const = torch.full_like(conv_out, self.fill)
        return torch.cat([conv_out, const], dim=1)


class ConvTranspose1dSingle(torch.nn.Module):
def __init__(self, bias=True, dilation=1):
super().__init__()
Expand Down
24 changes: 24 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,17 @@ def test_qnn_backend_einsum_outer_product_relu(self):
)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_full_layout_transformed(self):
    # Conv output is NHWC-transformed by the backend; the full() constant
    # concatenated with it must follow the same layout.
    module = ConvFull(0.5, (1, 16, 4, 6))  # noqa: F405
    sample_input = (torch.randn(1, 8, 4, 6),)
    self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_full_like_layout_transformed(self):
    # Conv output is NHWC-transformed by the backend; the full_like()
    # constant concatenated with it must follow the same layout.
    sample_input = (torch.randn(1, 8, 4, 6),)
    module = ConvFullLike(0.5)  # noqa: F405
    self.lower_module_and_test_output(module, sample_input)

# TODO: Create a new UT class for passes specific checks
def test_qnn_backend_lift_add_tensor(self):
module = LiftAddTensor() # noqa: F405
Expand Down Expand Up @@ -5095,6 +5106,19 @@ def test_qnn_backend_einsum_outer_product_relu(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_full_layout_transformed(self):
    # Quantized variant: the module is converted to QDQ form before
    # lowering so the full() constant goes through the same layout pass.
    sample_input = (torch.randn(1, 8, 4, 6),)
    module = ConvFull(0.5, (1, 16, 4, 6))  # noqa: F405
    qdq_module = self.get_qdq_module(module, sample_input)
    self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_full_like_layout_transformed(self):
    # Quantized variant: the module is converted to QDQ form before
    # lowering so the full_like() constant goes through the same layout pass.
    sample_input = (torch.randn(1, 8, 4, 6),)
    module = ConvFullLike(0.5)  # noqa: F405
    qdq_module = self.get_qdq_module(module, sample_input)
    self.lower_module_and_test_output(qdq_module, sample_input)

@unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
def test_qnn_backend_masked_softmax(self):
if self.enable_x86_64:
Expand Down
Loading