Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions backends/cadence/generic/operators/op_quantized_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,13 @@ void quantized_conv2d_nhwc(
const int c = static_cast<int>(conv1d ? input.size(2) : input.size(3));
// Depthwise is defined by in_channels == groups; depthwise weights have one
// fewer dim than regular weights because the IC dim (always 1) was squeezed.
const bool is_depthwise =
!conv1d && c == groups && weight.dim() < input.dim();
const bool is_depthwise = c == groups && weight.dim() < input.dim();
int oc, wh, ww, wc;
if (is_depthwise) {
// Depthwise weight is [KH, KW, OC]
wh = static_cast<int>(weight.size(0));
ww = static_cast<int>(weight.size(1));
oc = static_cast<int>(weight.size(2));
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
wh = static_cast<int>(conv1d ? 1 : weight.size(0));
ww = static_cast<int>(conv1d ? weight.size(0) : weight.size(1));
oc = static_cast<int>(conv1d ? weight.size(1) : weight.size(2));
wc = 1;
} else {
// Regular weight is [OC, WH, WW, WC] or for conv1d [OC, WW, WC]
Expand Down
22 changes: 10 additions & 12 deletions backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,16 @@ void xa_opt_quantized_conv2d_nhwc(
// Depthwise is defined by in_channels == groups; depthwise weights have one
// fewer dim than regular weights because the IC dim (always 1) was
// squeezed.
bool is_depthwise =
!conv1d && input_channels == groups && weight.dim() < input.dim();
bool is_depthwise = input_channels == groups && weight.dim() < input.dim();
WORD32 kernel_height;
WORD32 kernel_width;
WORD32 kernel_channels;
WORD32 out_channels;
if (is_depthwise) {
// Depthwise weight is [KH, KW, OC]
kernel_height = weight.size(0);
kernel_width = weight.size(1);
out_channels = weight.size(2);
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
kernel_height = conv1d ? 1 : weight.size(0);
kernel_width = conv1d ? weight.size(0) : weight.size(1);
out_channels = conv1d ? weight.size(1) : weight.size(2);
kernel_channels = 1;
} else {
// Regular weight is [OC, IC, KH, KW] or for conv1d [OC, K, IC]
Expand Down Expand Up @@ -384,14 +383,13 @@ void quantized_conv2d_nhwc(
const int c = conv1d ? input.size(2) : input.size(3);
// Depthwise is defined by in_channels == groups; depthwise weights have one
// fewer dim than regular weights because the IC dim (always 1) was squeezed.
const bool is_depthwise =
!conv1d && c == groups && weight.dim() < input.dim();
const bool is_depthwise = c == groups && weight.dim() < input.dim();
int oc, wh, ww, wc;
if (is_depthwise) {
// Depthwise weight is [KH, KW, OC]
wh = weight.size(0);
ww = weight.size(1);
oc = weight.size(2);
// Depthwise weight: conv2d=[KH, KW, OC], conv1d=[K, OC]
wh = conv1d ? 1 : weight.size(0);
ww = conv1d ? weight.size(0) : weight.size(1);
oc = conv1d ? weight.size(1) : weight.size(2);
wc = 1;
} else {
// Regular weight is [OC, WH, WW, WC] or for conv1d [OC, WW, WC]
Expand Down
Loading