Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions coremltools/optimize/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,6 @@ def _update_tensor_range(
"""
tensor_min = np.min(np.array(tensor_value).flatten())
tensor_max = np.max(np.array(tensor_value).flatten())
activation_stats_dict[tensor_name]["rmin"] = tensor_min
activation_stats_dict[tensor_name]["rmax"] = tensor_max
if tensor_name in activation_stats_dict:
activation_stats_dict[tensor_name]["rmin"] = min(
tensor_min, activation_stats_dict[tensor_name]["rmin"]
Expand Down
2 changes: 0 additions & 2 deletions coremltools/optimize/coreml/experimental/_model_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ def predict_intermediate_outputs(
def record_intermediate_output(output_value, output_name, activation_stats_dict):
tensor_min = np.min(output_value.flatten())
tensor_max = np.max(output_value.flatten())
activation_stats_dict[output_name]["rmin"] = tensor_min
activation_stats_dict[output_name]["rmax"] = tensor_max
if output_name in activation_stats_dict:
activation_stats_dict[output_name]["rmin"] = min(
tensor_min, activation_stats_dict[output_name]["rmin"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ def _update_tensor_range(
) -> None:
tensor_min = np.min(np.array(tensor_value).flatten())
tensor_max = np.max(np.array(tensor_value).flatten())
activation_stats_dict[tensor_name]["rmin"] = tensor_min
activation_stats_dict[tensor_name]["rmax"] = tensor_max
if tensor_name in activation_stats_dict:
activation_stats_dict[tensor_name]["rmin"] = min(
tensor_min, activation_stats_dict[tensor_name]["rmin"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
from collections import defaultdict

import numpy as np
import torch
Expand All @@ -13,9 +14,11 @@
_get_activation_calibration_stats,
)
from coremltools.test.optimize.coreml.test_passes import TestCompressionPasses
from coremltools.optimize.coreml.experimental._model_debugger import ModelDebugger
import coremltools.optimize as cto



class TestActivationQuantization:
@staticmethod
def _get_test_mlmodel_conv_relu():
Expand Down Expand Up @@ -167,3 +170,75 @@ def test_get_activation_calibration_stats_concat_surrounding_ops(self):
# The mlmodel has a concat with 2 inputs and 1 output, so at least 3 rmin/rmax pairs in activation_stats should be identical.
# After deduplicating identical rmin/rmax pairs, the number of unique entries should therefore be at least 2 smaller than the original.
assert len(activation_stats) - len(activation_stats_unique) >= 2

def test_calibration_stats_accumulate_across_samples(self):
    """
    Regression test: activation stats from an earlier sample must not be overwritten
    by a later sample with a narrower activation range.

    The large-input sample is placed FIRST so the bug (unconditionally overwriting
    rmin/rmax on every call instead of accumulating the running min/max) would discard
    it, leaving only the activations from the last sample.

    Both bounds are checked: the combined run covers a superset of the samples of the
    wide-only run, so its rmax can only be larger-or-equal and its rmin only
    smaller-or-equal.
    """
    mlmodel = self._get_test_mlmodel_conv_relu()

    # Large inputs are intended to produce a wide activation range; all-zero inputs
    # are intended to produce a much narrower one.
    sample_wide = {"data": np.ones((5, 10, 4, 4), dtype=np.float32) * 100.0}
    sample_narrow = {"data": np.zeros((5, 10, 4, 4), dtype=np.float32)}
    stats_wide_only = _get_activation_calibration_stats(mlmodel, [sample_wide])
    stats_combined = _get_activation_calibration_stats(mlmodel, [sample_wide, sample_narrow])

    # Guard against a vacuous pass: with no recorded tensors the loop below checks nothing.
    assert stats_wide_only, "no activation stats were collected for the wide sample"

    for key in stats_wide_only:
        assert stats_combined[key]["rmax"] >= stats_wide_only[key]["rmax"] - 1e-5, (
            f"rmax for '{key}' was overwritten by the later narrow sample: "
            f"combined={stats_combined[key]['rmax']:.4f}, "
            f"wide-only={stats_wide_only[key]['rmax']:.4f}"
        )
        # Symmetric check for the lower bound, which the same overwrite bug would clobber.
        assert stats_combined[key]["rmin"] <= stats_wide_only[key]["rmin"] + 1e-5, (
            f"rmin for '{key}' was overwritten by the later narrow sample: "
            f"combined={stats_combined[key]['rmin']:.4f}, "
            f"wide-only={stats_wide_only[key]['rmin']:.4f}"
        )



class TestRecordIntermediateOutput:
    """Regression tests for ModelDebugger.record_intermediate_output in _model_debugger.py."""

    def _make_stats(self):
        """Return a fresh, empty stats container (a defaultdict of dicts)."""
        return defaultdict(dict)

    def test_second_call_narrower_does_not_overwrite(self):
        # Feed a wide batch followed by a narrower one; both bounds must survive.
        recorded = self._make_stats()
        for values in ([0.0, 10.0], [2.0, 5.0]):
            ModelDebugger.record_intermediate_output(np.array(values), "t", recorded)
        entry = recorded["t"]
        assert entry["rmin"] == 0.0, "rmin was overwritten by a narrower batch"
        assert entry["rmax"] == 10.0, "rmax was overwritten by a narrower batch"

    def test_rmax_from_first_call_not_overwritten_by_narrower_second_call(self):
        # The first batch sets rmax=10.0; the second batch ([2, 5]) is narrower.
        # Overwriting with the last call would give rmax=5.0, whereas keeping
        # the running max leaves rmax at 10.0.
        recorded = self._make_stats()
        first_batch = np.array([0.0, 10.0])
        second_batch = np.array([2.0, 5.0])
        ModelDebugger.record_intermediate_output(first_batch, "t", recorded)
        ModelDebugger.record_intermediate_output(second_batch, "t", recorded)
        observed_rmax = recorded["t"]["rmax"]
        assert observed_rmax == 10.0, "rmax was overwritten by a narrower second call"

    def test_rmin_from_first_call_not_overwritten_by_narrower_second_call(self):
        # The first batch sets rmin=-1.0; the second batch ([2, 5]) is narrower.
        # Overwriting with the last call would give rmin=2.0, whereas keeping
        # the running min leaves rmin at -1.0.
        recorded = self._make_stats()
        first_batch = np.array([-1.0, 8.0])
        second_batch = np.array([2.0, 5.0])
        ModelDebugger.record_intermediate_output(first_batch, "t", recorded)
        ModelDebugger.record_intermediate_output(second_batch, "t", recorded)
        observed_rmin = recorded["t"]["rmin"]
        assert observed_rmin == -1.0, "rmin was overwritten by a narrower second call"

    def test_many_calls_accumulate_global_extremes(self):
        # The global extremes (1.0 and 9.0) come from different, non-final batches.
        recorded = self._make_stats()
        batch_values = (
            (3.0, 6.0),
            (1.0, 5.0),
            (4.0, 9.0),
            (2.0, 7.0),
        )
        for low, high in batch_values:
            ModelDebugger.record_intermediate_output(np.array([low, high]), "t", recorded)
        entry = recorded["t"]
        assert entry["rmin"] == 1.0
        assert entry["rmax"] == 9.0
1 change: 1 addition & 0 deletions coremltools/test/optimize/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
from collections import defaultdict

import itertools

Expand Down