Skip to content

Commit e1ecac0

Browse files
authored
Fix generic kernel depthwise NHWC conv and add tests
Differential Revision: D93620973 Pull Request resolved: pytorch#17529
1 parent a5423eb commit e1ecac0

1 file changed

Lines changed: 97 additions & 2 deletions

File tree

backends/cadence/generic/operators/op_quantized_conv2d.cpp

Lines changed: 97 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,101 @@ void quantized_conv2d_nchw(
392392
#undef typed_quantized_conv2d_nchw
393393
}
394394

395+
// Depthwise NHWC convolution.
396+
// Weight layout is [*kernel_size, OC]:
397+
// 2D: [KH, KW, OC] (3D tensor)
398+
// 1D: [K, OC] (2D tensor)
399+
// This differs from regular NHWC conv where weight is [OC, KH, KW, IC].
400+
void quantized_conv2d_nhwc_depthwise(
401+
const Tensor& input,
402+
const Tensor& weight,
403+
const Tensor& bias,
404+
IntArrayRef stride,
405+
IntArrayRef padding,
406+
IntArrayRef dilation,
407+
int16_t groups,
408+
int32_t in_zero_point,
409+
int32_t weight_zero_point,
410+
float bias_scale,
411+
float output_scale,
412+
int32_t output_zero_point,
413+
Tensor& out) {
414+
const bool conv1d = input.dim() == 3;
415+
416+
// input NHWC: [N, H, W, C] or [N, W, C] for 1D
417+
const int n = static_cast<int>(input.size(0));
418+
const int h = static_cast<int>(conv1d ? 1 : input.size(1));
419+
const int w = static_cast<int>(conv1d ? input.size(1) : input.size(2));
420+
const int c = static_cast<int>(conv1d ? input.size(2) : input.size(3));
421+
422+
// Depthwise weight: [KH, KW, OC] or [K, OC] for 1D
423+
const int kh = conv1d ? 1 : static_cast<int>(weight.size(0));
424+
const int kw = conv1d ? static_cast<int>(weight.size(0))
425+
: static_cast<int>(weight.size(1));
426+
const int oc = conv1d ? static_cast<int>(weight.size(1))
427+
: static_cast<int>(weight.size(2));
428+
429+
// output NHWC: [N, OH, OW, OC] or [N, OW, OC] for 1D
430+
const int oh = static_cast<int>(conv1d ? 1 : out.size(1));
431+
const int ow = static_cast<int>(conv1d ? out.size(1) : out.size(2));
432+
433+
const float inv_out_scale = 1.f / output_scale;
434+
435+
// Depthwise: each output channel depends on exactly one input channel.
436+
// ocpg = oc / groups output channels per group.
437+
const int ocpg = oc / groups;
438+
439+
#define typed_quantized_conv2d_nhwc_depthwise(ctype, dtype) \
440+
case ScalarType::dtype: { \
441+
const auto* p_in = input.const_data_ptr<ctype>(); \
442+
const auto* p_weight = weight.const_data_ptr<ctype>(); \
443+
const auto* p_bias = bias.const_data_ptr<int32_t>(); \
444+
auto* p_out = out.mutable_data_ptr<ctype>(); \
445+
for (int _n = 0; _n < n; ++_n) { \
446+
const ctype* in_batch = p_in + _n * h * w * c; \
447+
ctype* out_batch = p_out + _n * oh * ow * oc; \
448+
for (int _oh = 0; _oh < oh; ++_oh) { \
449+
for (int _ow = 0; _ow < ow; ++_ow) { \
450+
ctype* out_pixel = out_batch + (_oh * ow + _ow) * oc; \
451+
for (int _g = 0; _g < groups; ++_g) { \
452+
int soc = _g * ocpg; \
453+
for (int _oc = soc; _oc < soc + ocpg; ++_oc) { \
454+
float acc = p_bias[_oc]; \
455+
for (int _kh = 0; _kh < kh; ++_kh) { \
456+
for (int _kw = 0; _kw < kw; ++_kw) { \
457+
int ih = _oh * stride[0] + _kh * dilation[0] - padding[0]; \
458+
int iw = _ow * stride[1] + _kw * dilation[1] - padding[1]; \
459+
if (ih >= 0 && ih < h && iw >= 0 && iw < w) { \
460+
float lhs = \
461+
in_batch[ih * w * c + iw * c + _g] - in_zero_point; \
462+
float rhs = p_weight[_kh * kw * oc + _kw * oc + _oc] - \
463+
weight_zero_point; \
464+
acc += lhs * rhs; \
465+
} \
466+
} \
467+
} \
468+
float val = bias_scale * acc; \
469+
out_pixel[_oc] = quantize<ctype>( \
470+
val, inv_out_scale, (ctype)output_zero_point); \
471+
} \
472+
} \
473+
} \
474+
} \
475+
} \
476+
break; \
477+
}
478+
479+
ScalarType dtype = out.scalar_type();
480+
switch (dtype) {
481+
ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc_depthwise);
482+
default:
483+
ET_DCHECK_MSG(
484+
false, "Unhandled dtype %s", torch::executor::toString(dtype));
485+
}
486+
487+
#undef typed_quantized_conv2d_nhwc_depthwise
488+
}
489+
395490
void quantized_conv2d_nhwc(
396491
const Tensor& input,
397492
const Tensor& weight,
@@ -928,7 +1023,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out(
9281023
ET_UNUSED int64_t out_multiplier,
9291024
ET_UNUSED int64_t out_shift,
9301025
Tensor& out) {
931-
quantized_conv2d_nhwc(
1026+
quantized_conv2d_nhwc_depthwise(
9321027
input,
9331028
weight,
9341029
bias,
@@ -962,7 +1057,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out(
9621057
ET_UNUSED int64_t out_multiplier,
9631058
ET_UNUSED int64_t out_shift,
9641059
Tensor& out) {
965-
quantized_conv2d_nhwc(
1060+
quantized_conv2d_nhwc_depthwise(
9661061
input,
9671062
weight,
9681063
bias,

0 commit comments

Comments
 (0)