Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions coremltools/converters/mil/frontend/torch/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,12 +947,14 @@ def _construct_quantization_op(
zero_point: Optional[np.ndarray] = None
if compression_info.zero_point is not None:
zero_point = compression_info.zero_point.detach().numpy()
# For conv/conv_transpose, the weight has rank=4, so we auto-expand scale and zero-point if
# it only has two elements.
if len(weight.shape) == 4 and len(scale.shape) == 2:
scale = np.expand_dims(np.expand_dims(scale, axis=-1), axis=-1)
if zero_point is not None:
zero_point = np.expand_dims(np.expand_dims(zero_point, axis=-1), axis=-1)
# For conv/conv_transpose, the weight has rank >= 4 (rank=4 for Conv2d, rank=5 for Conv3d),
# so we auto-expand scale and zero-point if it only has two elements.
if len(weight.shape) >= 4 and len(scale.shape) == 2:
n_expand = len(weight.shape) - len(scale.shape)
for _ in range(n_expand):
scale = np.expand_dims(scale, axis=-1)
if zero_point is not None:
zero_point = np.expand_dims(zero_point, axis=-1)

if compressed_var is not None and compressed_var.op.op_type == "constexpr_lut_to_dense":
# The quantization on lut could lead to extra two dims at the end.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import pytest
import torch
import torch.nn as nn

ct = pytest.importorskip("coremltools")
import coremltools.test.optimize.torch.conversion.conversion_utils as util
Expand Down Expand Up @@ -60,6 +62,42 @@ def test_linear_quantizer(config, model, mnist_example_input, request):
# endregion


# region LinearQuantizer Conv3d
def test_linear_quantizer_conv3d_w8a8():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test fails even with your fix.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! The test was failing on macOS because verify_model_outputs hardcodes "input_1" as the predict input key ([conversion_utils.py:107]), but the forward parameter x caused CoreML to name the input "x_1" instead. Fixed by renaming the parameter to input_1 to match the convention used by the rest of the test suite.

"""Regression test for https://github.com/apple/coremltools/issues/2452.

w8a8 symmetric quantization of Conv3d failed with:
ValueError: the `weight` should have same rank as `scale`,
but got (1, 1, 3, 3, 3) vs (1, 1)
because the per-tensor scale was not reshaped to match the 5-D weight rank.
"""

class Conv3dModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv3d(1, 1, kernel_size=3, padding=1)

def forward(self, input_1):
return self.conv1(input_1)

example_input = torch.randn(1, 1, 8, 8, 8)
config = LinearQuantizerConfig.from_dict(
{"global_config": {"quantization_scheme": "symmetric"}}
)
quantizer = LinearQuantizer(Conv3dModel().eval(), config)
quantized_model = get_quantized_model(quantizer, example_input)

util.convert_and_verify(
quantized_model,
example_input,
minimum_deployment_target=ct.target.iOS17,
expected_ops=["constexpr_affine_dequantize"],
)


# endregion


# region GPTQ
@pytest.mark.parametrize(
"config",
Expand Down