Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,24 @@ def _insert_image_preprocessing_ops(block):
last_output = input_var
input_nptype = nptype_from_builtin(type(last_output.dtype()))
if input_type.scale != 1:
scale_arr = np.array(input_type.scale, dtype=input_nptype)
# For per-channel scale (a list) on RGB/BGR images, the array
# has shape `(3,)` and would otherwise try to broadcast against
# the last axis of the channel-first input. Reshape to the same
# broadcast layout used by `bias` below.
if scale_arr.ndim > 0 and input_type.color_layout not in (
_input_types.ColorLayout.GRAYSCALE,
_input_types.ColorLayout.GRAYSCALE_FLOAT16,
):
if len(last_output.shape) == 3:
scale_arr = scale_arr.reshape([3, 1, 1])
elif len(last_output.shape) == 4:
scale_arr = scale_arr.reshape([1, 3, 1, 1])
else:
raise TypeError("Unsupported rank for image input type.")
last_output = mb.mul(
x=last_output,
y=np.array(input_type.scale, dtype=input_nptype),
y=scale_arr,
name=input_var.name + "__scaled__",
)
if has_bias:
Expand Down
84 changes: 84 additions & 0 deletions coremltools/converters/mil/backend/mil/passes/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,90 @@ def prog(x):
add_op = prog.find_ops(op_type="add", exactly_one=False)[0]
assert add_op.y.dtype() == prog.functions["main"].inputs["x"].dtype()

def test_program_rgb_per_channel_scale(self):
    """
    Per-channel ``ImageType.scale`` on a rank-4 channel-first RGB input.

    Regression test for https://github.com/apple/coremltools/issues/2461.
    With a list-valued scale (the documented ``scale: float or list of
    floats`` form), the preprocessing pass used to fail with::

        ValueError: Incompatible dim 3 in shapes (1, 3, 224, 224) vs. (1, 1, 1, 3)

    The scale constant kept its ``(3,)`` shape and so broadcast against the
    last axis of the input. This test checks that the inserted ``mul``
    constant matches the scale values laid out as ``(1, 3, 1, 1)``, and that
    the per-channel bias ``add`` constant uses the same layout.
    """
    per_channel_scale = (1 / 127.5, 1 / 127.5, 1 / 127.5)
    per_channel_bias = (-1.0, -1.0, -1.0)
    broadcast_layout = [1, 3, 1, 1]
    pass_name = "mil_backend::insert_image_preprocessing_ops"

    @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))])
    def prog(x):
        first_relu = mb.relu(x=x)
        second_relu = mb.relu(x=x)
        return mb.add(x=first_relu, y=second_relu)

    image_input = ct.ImageType(
        name="x",
        shape=[1, 3, 20, 20],
        scale=list(per_channel_scale),
        bias=list(per_channel_bias),
        color_layout="RGB",
        channel_first=True,
    )
    prog.functions["main"].input_types = (image_input,)

    # Only the pre-pass program is inspected; the two blocks are unused here.
    prev_prog, _prev_block, _block = apply_pass_and_basic_check(prog, pass_name)

    # The pass must prepend exactly one scale `mul` and one bias `add`.
    assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"]
    assert get_op_types_in_program(prog) == ["mul", "add", "relu", "relu", "add"]

    mul_ops = prog.find_ops(op_type="mul", exactly_one=True)
    expected_scale = np.array(per_channel_scale).reshape(broadcast_layout)
    np.testing.assert_allclose(mul_ops[0].y.val, expected_scale)

    # The first `add` in program order is the inserted bias op; the second
    # one belongs to the original program body.
    add_ops = prog.find_ops(op_type="add", exactly_one=False)
    expected_bias = np.array(per_channel_bias).reshape(broadcast_layout)
    np.testing.assert_allclose(add_ops[0].y.val, expected_bias)

def test_program_rgb_per_channel_scale_rank3(self):
    """
    Per-channel ``ImageType.scale`` on a rank-3 channel-first RGB input.

    Companion to the rank-4 per-channel scale test: for a ``(3, H, W)``
    input, both the scale ``mul`` constant and the bias ``add`` constant
    are expected to match their values laid out as ``(3, 1, 1)``.
    """
    per_channel_scale = (0.5, 0.25, 0.125)
    per_channel_bias = (1.0, 2.0, 3.0)
    broadcast_layout = [3, 1, 1]
    pass_name = "mil_backend::insert_image_preprocessing_ops"

    @mb.program(input_specs=[mb.TensorSpec(shape=(3, 20, 20))])
    def prog(x):
        return mb.relu(x=x)

    image_input = ct.ImageType(
        name="x",
        shape=[3, 20, 20],
        scale=list(per_channel_scale),
        bias=list(per_channel_bias),
        color_layout="RGB",
        channel_first=True,
    )
    prog.functions["main"].input_types = (image_input,)

    # The returned triple is not inspected; the pass mutates `prog` in place.
    _prev_prog, _prev_block, _block = apply_pass_and_basic_check(prog, pass_name)

    assert get_op_types_in_program(prog) == ["mul", "add", "relu"]

    mul_ops = prog.find_ops(op_type="mul", exactly_one=True)
    expected_scale = np.array(per_channel_scale).reshape(broadcast_layout)
    np.testing.assert_allclose(mul_ops[0].y.val, expected_scale)

    add_ops = prog.find_ops(op_type="add", exactly_one=True)
    expected_bias = np.array(per_channel_bias).reshape(broadcast_layout)
    np.testing.assert_allclose(add_ops[0].y.val, expected_bias)


class TestSanitizerPass:

def test_sanitize_numeric_var_names(self):
Expand Down