Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions coremltools/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,21 +2069,15 @@ def _verify_output_correctness_of_chunks(
pipeline_model: _Optional["_ct.models.MLModel"] = None,
) -> None:
"""Verifies the end-to-end output correctness of full (original) model versus chunked models"""
# lazy import avoids circular error
from coremltools.converters.mil.testing_utils import compute_snr_and_psnr
from coremltools.converters.mil.testing_utils import (
random_gen_input_feature_type as random_gen_input_feature_type,
)

def report_correctness(original_outputs: _np.ndarray, final_outputs: _np.ndarray, log_prefix: str):
""" Report PSNR values across two compatible tensors.
This util is from https://github.com/apple/ml-stable-diffusion/blob/main/python_coreml_stable_diffusion/torch2coreml.py#L80,
with a slightly modification.
"""
ABSOLUTE_MIN_PSNR = 35

_, original_psnr = compute_snr_and_psnr(original_outputs, original_outputs)
_, final_psnr = compute_snr_and_psnr(original_outputs, final_outputs)
_, original_psnr = _compute_snr_and_psnr(original_outputs, original_outputs)
_, final_psnr = _compute_snr_and_psnr(original_outputs, final_outputs)

dB_change = final_psnr - original_psnr
_logger.info(
Expand All @@ -2102,7 +2096,7 @@ def report_correctness(original_outputs: _np.ndarray, final_outputs: _np.ndarray
# Generate inputs for first chunk and full model
input_dict = {}
for input_desc in full_model._spec.description.input:
input_dict[input_desc.name] = random_gen_input_feature_type(input_desc)
input_dict[input_desc.name] = _random_gen_input_feature_type(input_desc)

# Generate outputs for full model
outputs_from_full_model = full_model.predict(input_dict)
Expand Down Expand Up @@ -2140,6 +2134,61 @@ def report_correctness(original_outputs: _np.ndarray, final_outputs: _np.ndarray
)


def _random_gen_input_feature_type(input_desc):
if input_desc.type.WhichOneof("Type") == "multiArrayType":
array_type = input_desc.type.multiArrayType
array_feature_type = _proto.FeatureTypes_pb2.ArrayFeatureType
shape = [s for s in array_type.shape]
if array_type.dataType == array_feature_type.FLOAT32:
dtype = _np.float32
elif array_type.dataType == array_feature_type.INT32:
dtype = _np.int32
elif array_type.dataType == array_feature_type.FLOAT16:
dtype = _np.float16
elif array_type.dataType == array_feature_type.FLOAT64:
dtype = _np.float64
else:
raise ValueError("unsupported type")
return _np.random.rand(*shape).astype(dtype)
elif input_desc.type.WhichOneof("Type") == "imageType":
from PIL import Image as _Image

image_type = input_desc.type.imageType
image_feature_type = _proto.FeatureTypes_pb2.ImageFeatureType
if image_type.colorSpace in (
image_feature_type.BGR,
image_feature_type.RGB,
):
shape = [3, image_type.height, image_type.width]
x = _np.random.randint(low=0, high=256, size=shape)
return _Image.fromarray(_np.transpose(x, [1, 2, 0]).astype(_np.uint8))
elif image_type.colorSpace == image_feature_type.GRAYSCALE:
shape = [image_type.height, image_type.width]
x = _np.random.randint(low=0, high=256, size=shape)
return _Image.fromarray(x.astype(_np.uint8), "L")
elif image_type.colorSpace == image_feature_type.GRAYSCALE_FLOAT16:
shape = (image_type.height, image_type.width)
x = _np.random.rand(*shape)
return _Image.fromarray(x.astype(_np.float32), "F")
else:
raise ValueError("unrecognized image type")
else:
raise ValueError("unsupported type")


def _compute_snr_and_psnr(x, y):
assert len(x) == len(y)
eps = 1e-5
eps2 = 1e-10
noise = x - y
noise_var = _np.sum(noise**2) / len(noise)
signal_energy = _np.sum(y**2) / len(y)
max_signal_energy = _np.amax(y**2)
snr = 10 * _np.log10((signal_energy + eps) / (noise_var + eps2))
psnr = 10 * _np.log10((max_signal_energy + eps) / (noise_var + eps2))
return snr, psnr


def _get_op_idx_split_location(prog: _mil.Program) -> _Tuple[int, int, int]:
"""Find the op that approximately bisects the graph as measure by weights size on each side"""
main_block = prog.functions["main"]
Expand Down
38 changes: 38 additions & 0 deletions coremltools/test/ml_program/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import builtins
import copy
import itertools
import os
Expand Down Expand Up @@ -1502,6 +1503,43 @@ def validate_inference(multifunction_mlpackage_path: str) -> None:
reason="rdar://157488825 (Python 3.13 Unit Test Segmentation Fault)",
)
class TestBisectModel:
@staticmethod
def test_verify_output_correctness_does_not_import_testing_utils(monkeypatch):
class DummyModel:
def __init__(self, inputs, outputs):
spec = proto.Model_pb2.Model()
spec.description.input.extend(inputs)
self._spec = spec
self._outputs = outputs

def predict(self, input_dict):
return self._outputs

input_desc = proto.Model_pb2.FeatureDescription()
input_desc.name = "x"
input_desc.type.multiArrayType.shape.extend([1, 2])
input_desc.type.multiArrayType.dataType = (
proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT32
)

outputs = {"out": np.array([1.0, 2.0], dtype=np.float32)}
full_model = DummyModel([input_desc], outputs)
pipeline_model = DummyModel([input_desc], outputs)

real_import = builtins.__import__

def guarded_import(name, *args, **kwargs):
if name == "coremltools.converters.mil.testing_utils":
raise ModuleNotFoundError("No module named 'pytest'")
return real_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", guarded_import)

ct.models.utils._verify_output_correctness_of_chunks(
full_model=full_model,
pipeline_model=pipeline_model,
)

@staticmethod
def check_spec_op_type(model_path, expected_ops):
spec = load_spec(model_path)
Expand Down