Skip to content

Commit 0cc005a

Browse files
mcremon-metameta-codesync[bot]
authored and committed
Add dedicated HiFi kernel for max pool 2d (#18240)
Summary: Pull Request resolved: #18240 As titled. Calls into nnlib directly. Differential Revision: D96874522 Reviewed By: hsharma35
1 parent bdd6080 commit 0cc005a

8 files changed

Lines changed: 187 additions & 25 deletions

File tree

backends/cadence/aot/ops_registrations.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,27 +2292,27 @@ def quantized_max_pool2d_nchw_meta(
22922292
dilation: list[int],
22932293
ceil_mode: bool,
22942294
) -> torch.Tensor:
2295-
assert len(kernel_size) == 2, f"kernel_size must have 2 elements, got {len(kernel_size)}"
2295+
assert (
2296+
len(kernel_size) == 2
2297+
), f"kernel_size must have 2 elements, got {len(kernel_size)}"
22962298
assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
22972299
assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
22982300
assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
2299-
assert len(input.size()) == 4, f"input must be 4D (N, C, H, W), got {len(input.size())}D"
2301+
assert (
2302+
len(input.size()) == 4
2303+
), f"input must be 4D (N, C, H, W), got {len(input.size())}D"
23002304

23012305
batch = input.size(0)
23022306
channels = input.size(1)
23032307
height_in = input.size(2)
23042308
width_in = input.size(3)
23052309

23062310
height_out_raw = (
2307-
(height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
2308-
/ stride[0]
2309-
+ 1
2310-
)
2311+
height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
2312+
) / stride[0] + 1
23112313
width_out_raw = (
2312-
(width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
2313-
/ stride[1]
2314-
+ 1
2315-
)
2314+
width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
2315+
) / stride[1] + 1
23162316

23172317
if ceil_mode:
23182318
height_out = ceil(height_out_raw)
@@ -2333,27 +2333,27 @@ def quantized_max_pool2d_nhwc_meta(
23332333
dilation: list[int],
23342334
ceil_mode: bool,
23352335
) -> torch.Tensor:
2336-
assert len(kernel_size) == 2, f"kernel_size must have 2 elements, got {len(kernel_size)}"
2336+
assert (
2337+
len(kernel_size) == 2
2338+
), f"kernel_size must have 2 elements, got {len(kernel_size)}"
23372339
assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
23382340
assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
23392341
assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
2340-
assert len(input.size()) == 4, f"input must be 4D (N, H, W, C), got {len(input.size())}D"
2342+
assert (
2343+
len(input.size()) == 4
2344+
), f"input must be 4D (N, H, W, C), got {len(input.size())}D"
23412345

23422346
batch = input.size(0)
23432347
height_in = input.size(1)
23442348
width_in = input.size(2)
23452349
channels = input.size(3)
23462350

23472351
height_out_raw = (
2348-
(height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
2349-
/ stride[0]
2350-
+ 1
2351-
)
2352+
height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
2353+
) / stride[0] + 1
23522354
width_out_raw = (
2353-
(width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
2354-
/ stride[1]
2355-
+ 1
2356-
)
2355+
width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
2356+
) / stride[1] + 1
23572357

23582358
if ceil_mode:
23592359
height_out = ceil(height_out_raw)

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,9 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
739739
dequants_biases,
740740
op_node,
741741
)
742-
elif isinstance(pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)):
742+
elif isinstance(
743+
pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)
744+
):
743745
args, kwargs = get_args_and_kwargs_max_pool2d(
744746
inputs_inputs,
745747
op_node,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ def replacement_op(self) -> OpOverload:
503503

504504
# This is a base class for ReLU
505505

506+
506507
# This is a base class for ReLU, since it can be used with two different aten ops
507508
class ReluBasePattern(QuantizationPattern):
508509
@abstractmethod

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,9 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
505505
target=torch.ops.aten.max_pool2d_with_indices.default,
506506
)
507507
self.assertEqual(
508-
len(max_pool_nodes), 1, "Should find exactly one max_pool2d_with_indices node"
508+
len(max_pool_nodes),
509+
1,
510+
"Should find exactly one max_pool2d_with_indices node",
509511
)
510512
return gm, max_pool_nodes[0]
511513

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,9 +2599,7 @@ def test_replace_max_pool2d_nchw_with_nhwc(self) -> None:
25992599
self.assertEqual(
26002600
count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1
26012601
)
2602-
self.assertEqual(
2603-
count_node(gm, exir_ops.edge.aten.permute_copy.default), 0
2604-
)
2602+
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
26052603

26062604
# Deepcopy before the pass
26072605
original = copy.deepcopy(gm)

backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <algorithm>
1212
#include <cstdint>
13+
#include <cstring>
1314
#include <limits>
1415

1516
#include <executorch/backends/cadence/generic/operators/cadence_type_util.h>
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <xa_nnlib_kernels_api.h>

namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::IntArrayRef;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

// Dedicated HiFi max-pool-2d kernel for quantized NHWC tensors.
// Dispatches directly into the Cadence nnlib maxpool routines
// (xa_nn_maxpool_8 for int8, xa_nn_maxpool_asym8 for uint8), processing
// one batch at a time. `output` must already be sized to the expected
// pooled shape [N, H_out, W_out, C]; it is returned for chaining.
Tensor& quantized_max_pool2d_nhwc_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& output) {
  // NHWC layout: [N, H, W, C]
  const int32_t batch_size = input.size(0);
  const int32_t in_height = input.size(1);
  const int32_t in_width = input.size(2);
  const int32_t channels = input.size(3);

  // Output spatial extents come straight from the pre-sized output tensor,
  // so ceil_mode is already baked into them by the meta/shape function.
  const int32_t out_height = output.size(1);
  const int32_t out_width = output.size(2);

  const int32_t kernel_h = kernel_size[0];
  const int32_t kernel_w = kernel_size[1];
  const int32_t stride_h = stride[0];
  const int32_t stride_w = stride[1];
  const int32_t pad_h = padding[0];
  const int32_t pad_w = padding[1];

  // NOTE(review): `dilation` and `ceil_mode` are not forwarded to nnlib —
  // the xa_nn_maxpool API takes no dilation argument and the output size
  // encodes ceil_mode. This silently assumes dilation == [1, 1]; confirm
  // that upstream lowering guarantees it.
  (void)dilation;
  (void)ceil_mode;

  // Map the tensor dtype onto the nnlib precision constants.
  ScalarType dtype = input.scalar_type();
  int32_t nnlib_precision;
  switch (dtype) {
    case ScalarType::Char: // int8
      nnlib_precision = PREC_SYM8S;
      break;
    case ScalarType::Byte: // uint8
      nnlib_precision = PREC_ASYM8U;
      break;
    default:
      ET_DCHECK_MSG(
          false,
          "Unsupported dtype %s for HiFi quantized_max_pool2d_nhwc",
          torch::executor::toString(dtype));
      return output;
  }

  // Compute scratch buffer size for the nnlib maxpool; a negative return
  // indicates an invalid parameter combination.
  int32_t scratch_size = xa_nn_maxpool_getsize(
      channels,
      nnlib_precision,
      nnlib_precision,
      in_height,
      in_width,
      kernel_h,
      kernel_w,
      stride_w, // x_stride
      stride_h, // y_stride
      pad_w, // x_padding
      pad_h, // y_padding
      out_height,
      out_width,
      0, // inp_data_format: 0 = NHWC
      0); // out_data_format: 0 = NHWC
  ET_DCHECK_MSG(scratch_size >= 0, "xa_nn_maxpool_getsize failed");

  // Allocate aligned scratch memory. Guard against a failed allocation:
  // nnlib would otherwise write through a null scratch pointer.
  void* p_scratch = kernels::allocate_temp_memory(ctx, scratch_size);
  ET_DCHECK_MSG(
      p_scratch != nullptr || scratch_size == 0,
      "Failed to allocate %d bytes of scratch memory",
      scratch_size);

  // Process each batch using the nnlib optimized maxpool kernel. Per-batch
  // offsets are in elements (NHWC is contiguous per batch).
  for (int32_t n = 0; n < batch_size; ++n) {
    const int32_t spatial_size = in_height * in_width * channels;
    const int32_t out_spatial_size = out_height * out_width * channels;

    int32_t ret;
    if (dtype == ScalarType::Char) {
      const int8_t* in_batch =
          input.const_data_ptr<int8_t>() + n * spatial_size;
      int8_t* out_batch =
          output.mutable_data_ptr<int8_t>() + n * out_spatial_size;

      ret = xa_nn_maxpool_8(
          out_batch,
          in_batch,
          in_height,
          in_width,
          channels,
          kernel_h,
          kernel_w,
          stride_w, // x_stride
          stride_h, // y_stride
          pad_w, // x_padding
          pad_h, // y_padding
          out_height,
          out_width,
          0, // inp_data_format: NHWC
          0, // out_data_format: NHWC
          p_scratch);
    } else {
      const uint8_t* in_batch =
          input.const_data_ptr<uint8_t>() + n * spatial_size;
      uint8_t* out_batch =
          output.mutable_data_ptr<uint8_t>() + n * out_spatial_size;

      ret = xa_nn_maxpool_asym8(
          out_batch,
          in_batch,
          in_height,
          in_width,
          channels,
          kernel_h,
          kernel_w,
          stride_w, // x_stride
          stride_h, // y_stride
          pad_w, // x_padding
          pad_h, // y_padding
          out_height,
          out_width,
          0, // inp_data_format: NHWC
          0, // out_data_format: NHWC
          p_scratch);
    }
    ET_DCHECK_MSG(ret == 0, "HiFi xa_nn_maxpool failed");
  }

  return output;
}

} // namespace native
} // namespace HiFi
} // namespace impl

backends/cadence/hifi/operators/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,16 @@ def define_common_targets():
632632
compatible_with = ["ovr_config//cpu:xtensa"],
633633
)
634634

635+
runtime.cxx_library(
636+
name = "op_quantized_max_pool2d_nhwc",
637+
srcs = ["op_quantized_max_pool2d_nhwc.cpp"],
638+
exported_headers = ["operators.h"],
639+
platforms = CXX,
640+
deps = COMMON_DEPS,
641+
visibility = ["PUBLIC"],
642+
compatible_with = ["ovr_config//cpu:xtensa"],
643+
)
644+
635645
runtime.cxx_library(
636646
name = "op_quantized_relu_asym8s_asym8s_per_tensor_out",
637647
srcs = ["op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp"],

0 commit comments

Comments
 (0)