Skip to content

Commit 3b7b0eb

Browse files
mcremon-meta authored and facebook-github-bot committed
Add dedicated HiFi kernel for max pool 2d (#18240)
Summary: As titled. Calls into nnlib directly. Reviewed By: hsharma35 Differential Revision: D96874522
1 parent 1c565e1 commit 3b7b0eb

3 files changed

Lines changed: 159 additions & 0 deletions

File tree

backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <algorithm>
1212
#include <cstdint>
13+
#include <cstring>
1314
#include <limits>
1415

1516
#include <executorch/backends/cadence/generic/operators/cadence_type_util.h>
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <xa_nnlib_kernels_api.h>
12+
13+
namespace impl {
14+
namespace HiFi {
15+
namespace native {
16+
17+
using ::executorch::aten::IntArrayRef;
18+
using ::executorch::aten::ScalarType;
19+
using ::executorch::aten::Tensor;
20+
using ::executorch::runtime::KernelRuntimeContext;
21+
22+
/**
 * Quantized 2D max pooling over an NHWC-layout tensor, dispatched to the
 * Cadence HiFi NNLIB kernels (xa_nn_maxpool_8 / xa_nn_maxpool_asym8).
 *
 * @param ctx         Kernel runtime context; used to allocate temp scratch.
 * @param input       Quantized input tensor, layout [N, H, W, C].
 *                    Supported dtypes: int8 (Char) and uint8 (Byte).
 * @param kernel_size {kernel_h, kernel_w}.
 * @param stride      {stride_h, stride_w}.
 * @param padding     {pad_h, pad_w}.
 * @param dilation    Ignored by the NNLIB kernel — presumably the op is only
 *                    registered for dilation == {1, 1}; TODO confirm with the
 *                    operator schema / graph-level checks.
 * @param ceil_mode   Ignored here; output spatial dims are taken from
 *                    `output`, which the caller is assumed to have sized
 *                    consistently with ceil_mode.
 * @param output      Pre-sized output tensor, layout [N, H_out, W_out, C].
 * @return            `output`, for chaining.
 */
Tensor& quantized_max_pool2d_nhwc_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& output) {
  // NHWC layout: [N, H, W, C]
  const int32_t batch_size = input.size(0);
  const int32_t in_height = input.size(1);
  const int32_t in_width = input.size(2);
  const int32_t channels = input.size(3);

  // Output spatial dims are trusted from the already-resized output tensor.
  const int32_t out_height = output.size(1);
  const int32_t out_width = output.size(2);

  const int32_t kernel_h = kernel_size[0];
  const int32_t kernel_w = kernel_size[1];
  const int32_t stride_h = stride[0];
  const int32_t stride_w = stride[1];
  const int32_t pad_h = padding[0];
  const int32_t pad_w = padding[1];

  // Determine NNLIB precision constants based on dtype. Only 8-bit
  // quantized types are supported by the HiFi maxpool kernels.
  ScalarType dtype = input.scalar_type();
  int32_t nnlib_precision;
  switch (dtype) {
    case ScalarType::Char: // int8
      nnlib_precision = PREC_SYM8S;
      break;
    case ScalarType::Byte: // uint8
      nnlib_precision = PREC_ASYM8U;
      break;
    default:
      ET_DCHECK_MSG(
          false,
          "Unsupported dtype %s for HiFi quantized_max_pool2d_nhwc",
          torch::executor::toString(dtype));
      return output;
  }

  // Compute scratch buffer size for NNLIB maxpool.
  int32_t scratch_size = xa_nn_maxpool_getsize(
      channels,
      nnlib_precision,
      nnlib_precision,
      in_height,
      in_width,
      kernel_h,
      kernel_w,
      stride_w, // x_stride
      stride_h, // y_stride
      pad_w, // x_padding
      pad_h, // y_padding
      out_height,
      out_width,
      0, // inp_data_format: 0 = NHWC
      0); // out_data_format: 0 = NHWC
  ET_DCHECK_MSG(scratch_size >= 0, "xa_nn_maxpool_getsize failed");

  // Allocate aligned scratch memory. Guard against allocation failure so we
  // never hand NNLIB a null scratch pointer.
  void* p_scratch = kernels::allocate_temp_memory(ctx, scratch_size);
  ET_DCHECK_MSG(
      p_scratch != nullptr || scratch_size == 0,
      "Failed to allocate %d bytes of scratch for HiFi maxpool",
      scratch_size);

  // Per-batch element counts are loop-invariant; compute them once.
  const int32_t spatial_size = in_height * in_width * channels;
  const int32_t out_spatial_size = out_height * out_width * channels;

  // Process each batch using the NNLIB optimized maxpool kernel. NNLIB
  // operates on a single [H, W, C] image, so we step the data pointers
  // batch by batch.
  for (int32_t n = 0; n < batch_size; ++n) {
    int32_t ret;
    if (dtype == ScalarType::Char) {
      const int8_t* in_batch =
          input.const_data_ptr<int8_t>() + n * spatial_size;
      int8_t* out_batch =
          output.mutable_data_ptr<int8_t>() + n * out_spatial_size;

      ret = xa_nn_maxpool_8(
          out_batch,
          in_batch,
          in_height,
          in_width,
          channels,
          kernel_h,
          kernel_w,
          stride_w, // x_stride
          stride_h, // y_stride
          pad_w, // x_padding
          pad_h, // y_padding
          out_height,
          out_width,
          0, // inp_data_format: NHWC
          0, // out_data_format: NHWC
          p_scratch);
    } else {
      const uint8_t* in_batch =
          input.const_data_ptr<uint8_t>() + n * spatial_size;
      uint8_t* out_batch =
          output.mutable_data_ptr<uint8_t>() + n * out_spatial_size;

      ret = xa_nn_maxpool_asym8(
          out_batch,
          in_batch,
          in_height,
          in_width,
          channels,
          kernel_h,
          kernel_w,
          stride_w, // x_stride
          stride_h, // y_stride
          pad_w, // x_padding
          pad_h, // y_padding
          out_height,
          out_width,
          0, // inp_data_format: NHWC
          0, // out_data_format: NHWC
          p_scratch);
    }
    ET_DCHECK_MSG(ret == 0, "HiFi xa_nn_maxpool failed");
  }

  return output;
}
145+
146+
} // namespace native
147+
} // namespace HiFi
148+
} // namespace impl

backends/cadence/hifi/operators/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,16 @@ def define_common_targets():
632632
compatible_with = ["ovr_config//cpu:xtensa"],
633633
)
634634

635+
runtime.cxx_library(
636+
name = "op_quantized_max_pool2d_nhwc",
637+
srcs = ["op_quantized_max_pool2d_nhwc.cpp"],
638+
exported_headers = ["operators.h"],
639+
platforms = CXX,
640+
deps = COMMON_DEPS,
641+
visibility = ["PUBLIC"],
642+
compatible_with = ["ovr_config//cpu:xtensa"],
643+
)
644+
635645
runtime.cxx_library(
636646
name = "op_quantized_relu_asym8s_asym8s_per_tensor_out",
637647
srcs = ["op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp"],

0 commit comments

Comments
 (0)