@@ -392,6 +392,101 @@ void quantized_conv2d_nchw(
392392#undef typed_quantized_conv2d_nchw
393393}
394394
395+ // Depthwise NHWC convolution.
396+ // Weight layout is [*kernel_size, OC]:
397+ // 2D: [KH, KW, OC] (3D tensor)
398+ // 1D: [K, OC] (2D tensor)
399+ // This differs from regular NHWC conv where weight is [OC, KH, KW, IC].
400+ void quantized_conv2d_nhwc_depthwise (
401+ const Tensor& input,
402+ const Tensor& weight,
403+ const Tensor& bias,
404+ IntArrayRef stride,
405+ IntArrayRef padding,
406+ IntArrayRef dilation,
407+ int16_t groups,
408+ int32_t in_zero_point,
409+ int32_t weight_zero_point,
410+ float bias_scale,
411+ float output_scale,
412+ int32_t output_zero_point,
413+ Tensor& out) {
414+ const bool conv1d = input.dim () == 3 ;
415+
416+ // input NHWC: [N, H, W, C] or [N, W, C] for 1D
417+ const int n = static_cast <int >(input.size (0 ));
418+ const int h = static_cast <int >(conv1d ? 1 : input.size (1 ));
419+ const int w = static_cast <int >(conv1d ? input.size (1 ) : input.size (2 ));
420+ const int c = static_cast <int >(conv1d ? input.size (2 ) : input.size (3 ));
421+
422+ // Depthwise weight: [KH, KW, OC] or [K, OC] for 1D
423+ const int kh = conv1d ? 1 : static_cast <int >(weight.size (0 ));
424+ const int kw = conv1d ? static_cast <int >(weight.size (0 ))
425+ : static_cast <int >(weight.size (1 ));
426+ const int oc = conv1d ? static_cast <int >(weight.size (1 ))
427+ : static_cast <int >(weight.size (2 ));
428+
429+ // output NHWC: [N, OH, OW, OC] or [N, OW, OC] for 1D
430+ const int oh = static_cast <int >(conv1d ? 1 : out.size (1 ));
431+ const int ow = static_cast <int >(conv1d ? out.size (1 ) : out.size (2 ));
432+
433+ const float inv_out_scale = 1 .f / output_scale;
434+
435+ // Depthwise: each output channel depends on exactly one input channel.
436+ // ocpg = oc / groups output channels per group.
437+ const int ocpg = oc / groups;
438+
439+ #define typed_quantized_conv2d_nhwc_depthwise (ctype, dtype ) \
440+ case ScalarType::dtype: { \
441+ const auto * p_in = input.const_data_ptr <ctype>(); \
442+ const auto * p_weight = weight.const_data_ptr <ctype>(); \
443+ const auto * p_bias = bias.const_data_ptr <int32_t >(); \
444+ auto * p_out = out.mutable_data_ptr <ctype>(); \
445+ for (int _n = 0 ; _n < n; ++_n) { \
446+ const ctype* in_batch = p_in + _n * h * w * c; \
447+ ctype* out_batch = p_out + _n * oh * ow * oc; \
448+ for (int _oh = 0 ; _oh < oh; ++_oh) { \
449+ for (int _ow = 0 ; _ow < ow; ++_ow) { \
450+ ctype* out_pixel = out_batch + (_oh * ow + _ow) * oc; \
451+ for (int _g = 0 ; _g < groups; ++_g) { \
452+ int soc = _g * ocpg; \
453+ for (int _oc = soc; _oc < soc + ocpg; ++_oc) { \
454+ float acc = p_bias[_oc]; \
455+ for (int _kh = 0 ; _kh < kh; ++_kh) { \
456+ for (int _kw = 0 ; _kw < kw; ++_kw) { \
457+ int ih = _oh * stride[0 ] + _kh * dilation[0 ] - padding[0 ]; \
458+ int iw = _ow * stride[1 ] + _kw * dilation[1 ] - padding[1 ]; \
459+ if (ih >= 0 && ih < h && iw >= 0 && iw < w) { \
460+ float lhs = \
461+ in_batch[ih * w * c + iw * c + _g] - in_zero_point; \
462+ float rhs = p_weight[_kh * kw * oc + _kw * oc + _oc] - \
463+ weight_zero_point; \
464+ acc += lhs * rhs; \
465+ } \
466+ } \
467+ } \
468+ float val = bias_scale * acc; \
469+ out_pixel[_oc] = quantize<ctype>( \
470+ val, inv_out_scale, (ctype)output_zero_point); \
471+ } \
472+ } \
473+ } \
474+ } \
475+ } \
476+ break ; \
477+ }
478+
479+ ScalarType dtype = out.scalar_type ();
480+ switch (dtype) {
481+ ET_FORALL_CADENCE_QUANTIZED_TYPES (typed_quantized_conv2d_nhwc_depthwise);
482+ default :
483+ ET_DCHECK_MSG (
484+ false , " Unhandled dtype %s" , torch::executor::toString (dtype));
485+ }
486+
487+ #undef typed_quantized_conv2d_nhwc_depthwise
488+ }
489+
395490void quantized_conv2d_nhwc (
396491 const Tensor& input,
397492 const Tensor& weight,
@@ -928,7 +1023,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out(
9281023 ET_UNUSED int64_t out_multiplier,
9291024 ET_UNUSED int64_t out_shift,
9301025 Tensor& out) {
931- quantized_conv2d_nhwc (
1026+ quantized_conv2d_nhwc_depthwise (
9321027 input,
9331028 weight,
9341029 bias,
@@ -962,7 +1057,7 @@ Tensor& quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out(
9621057 ET_UNUSED int64_t out_multiplier,
9631058 ET_UNUSED int64_t out_shift,
9641059 Tensor& out) {
965- quantized_conv2d_nhwc (
1060+ quantized_conv2d_nhwc_depthwise (
9661061 input,
9671062 weight,
9681063 bias,
0 commit comments