|
14 | 14 | get_conv1d_output_size, |
15 | 15 | get_conv2d_output_size, |
16 | 16 | get_im2row_output_size, |
| 17 | + is_depthwise_conv, |
17 | 18 | ) |
18 | 19 | from executorch.exir.scalar_type import ScalarType |
19 | 20 | from torch._meta_registrations import _linalg_svd_meta |
@@ -1034,11 +1035,8 @@ def quantized_conv2d_nhwc_meta( |
1034 | 1035 | assert len(in_size) < 6 |
1035 | 1036 |
|
1036 | 1037 | # Determine weight layout based on depthwise vs regular conv. |
1037 | | - # Depthwise is defined by in_channels == groups, where in_channels |
1038 | | - # is the last dim of the NHWC input. |
1039 | 1038 | in_channels = in_size[-1] |
1040 | | - is_depthwise = in_channels == groups |
1041 | | - if is_depthwise: |
| 1039 | + if is_depthwise_conv(groups, in_channels): |
1042 | 1040 | # Depthwise conv: weight is [*kernel_size, OC] |
1043 | 1041 | *kernel_size, out_channels = weight.shape |
1044 | 1042 | else: |
@@ -1177,12 +1175,8 @@ def quantized_conv2d_nhwc_per_tensor_meta( |
1177 | 1175 | assert len(in_size) < 6 |
1178 | 1176 |
|
1179 | 1177 | # Determine weight layout based on depthwise vs regular conv. |
1180 | | - # Depthwise is defined by in_channels == groups, where in_channels |
1181 | | - # is the last dim of the NHWC input. |
1182 | 1178 | in_channels = in_size[-1] |
1183 | | - is_depthwise = in_channels == groups |
1184 | | - if is_depthwise: |
1185 | | - # Depthwise conv: weight is [*kernel_size, OC] |
| 1179 | + if is_depthwise_conv(groups, in_channels): |
1186 | 1180 | *kernel_size, out_channels = weight.shape |
1187 | 1181 | elif len(in_size) == 3: |
1188 | 1182 | # 1D conv: weight is [OC, K, IC] |
@@ -1336,12 +1330,9 @@ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( |
1336 | 1330 | assert len(in_size) > 2 |
1337 | 1331 | assert len(in_size) < 6 |
1338 | 1332 |
|
1339 | | - # Determine weight layout based on input and weight dimensions: |
1340 | | - # - Depthwise conv: input is 3D/4D, weight is 2/3D [K, OC]/[KH, KW, OC] |
1341 | | - # - 1D conv: input is 3D, weight is 3D [OC, K, IC] |
1342 | | - # - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC] |
1343 | | - if len(weight.shape) == 3: |
1344 | | - # 2D depthwise conv: weight is [KH, KW, OC] |
| 1333 | + # Determine weight layout based on depthwise vs regular conv. |
| 1334 | + in_channels = in_size[-1] |
| 1335 | + if is_depthwise_conv(groups, in_channels): |
1345 | 1336 | *kernel_size, out_channels = weight.shape |
1346 | 1337 | elif len(in_size) == 3: |
1347 | 1338 | # 1D conv: weight is [OC, K, IC] |
@@ -1397,12 +1388,9 @@ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( |
1397 | 1388 | assert len(in_size) > 2 |
1398 | 1389 | assert len(in_size) < 6 |
1399 | 1390 |
|
1400 | | - # Determine weight layout based on input and weight dimensions: |
1401 | | - # - Depthwise conv: input is 3D/4D, weight is 3D [KH, KW, OC] |
1402 | | - # - 1D conv: input is 3D, weight is 3D [OC, K, IC] |
1403 | | - # - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC] |
1404 | | - if len(weight.shape) == 3: |
1405 | | - # 2D depthwise conv: weight is [KH, KW, OC] |
| 1391 | + # Determine weight layout based on depthwise vs regular conv. |
| 1392 | + in_channels = in_size[-1] |
| 1393 | + if is_depthwise_conv(groups, in_channels): |
1406 | 1394 | *kernel_size, out_channels = weight.shape |
1407 | 1395 | elif len(in_size) == 3: |
1408 | 1396 | # 1D conv: weight is [OC, K, IC] |
|
0 commit comments