Fix activation scale inf issue for const_weight and const_scale #2448
Conversation
Signed-off-by: Gao, Qun <qun.gao@intel.com>
Azure Pipelines: Successfully started running 1 pipeline(s). 8 pipeline(s) were filtered out due to trigger conditions. |
Pull request overview
Fixes issues when static quantization calibration produces inf activation scales (e.g., too little calibration data) in combination with const_scale / const_weight, aiming to avoid corrupting layer state and producing incorrect outputs.
Changes:
- Add early-return branches in `post_quantization_cleanup()` to skip cleanup/const-conversion when a layer is not quantized.
- Attempt to remove `input_observer` from tracked sublayers when quantization is skipped.
- Reset `_const_variables` when quantization is skipped.
```python
if not self._is_quantized:
    # Clean up observer only if it exists
    if hasattr(self, "input_observer"):
        if hasattr(self, "_layers") and self.input_observer in self._layers:
            self._layers.remove(self.input_observer)
    # Set call to pass-through/original
    if hasattr(self, "call"):
        # pass through
        pass
    self._const_variables = []
    self._tracker.lock()
    return
```
The new if not self._is_quantized: ... return short-circuits post_quantization_cleanup() in the normal successful quantization path because _is_quantized is only set to True at the end of this method. This prevents switching self.call to call_symmetric/call_asymmetric and prevents converting const vars, so static quantization will effectively never activate (and deserialized models will still have call() pointing at input_observer, which may not exist). Consider removing this early-return and instead gating the skip-path on a separate flag set by convert() when calibration fails (e.g., _skip_quantization=True), where you also explicitly set call to a pass-through implementation and delete/clear input_observer consistently.
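The reviewer's suggestion can be sketched as follows. This is a minimal, self-contained illustration, not the repository's actual code: the class name `QuantizedLayerSketch` and the flag `_skip_quantization` are assumed names, and the real `convert()` / `post_quantization_cleanup()` do considerably more work.

```python
# Sketch of gating the early return on a dedicated skip flag set by convert()
# when calibration fails, instead of on _is_quantized (which is only set True
# at the end of post_quantization_cleanup and would short-circuit the normal
# path). All names here are illustrative assumptions.
import math


class QuantizedLayerSketch:
    def __init__(self):
        self.input_observer = object()       # stand-in for the real observer
        self._layers = [self.input_observer]  # tracked sublayers
        self._skip_quantization = False       # set by convert() on inf scale
        self._is_quantized = False
        self._const_variables = []

    def convert(self, activation_scale):
        # Calibration failure: an inf scale means too little data was seen.
        if math.isinf(activation_scale):
            self._skip_quantization = True

    def post_quantization_cleanup(self):
        if self._skip_quantization:
            # Restore the pass-through call path and drop the observer, so
            # call() never points at an observer that may not exist later.
            self.call = self._passthrough_call
            if self.input_observer in self._layers:
                self._layers.remove(self.input_observer)
            self._const_variables = []
            return
        # Normal successful path: switch to the quantized call.
        self.call = self._quantized_call
        self._is_quantized = True

    def _passthrough_call(self, x):
        return x

    def _quantized_call(self, x):
        return x  # placeholder for the real quantized implementation


layer = QuantizedLayerSketch()
layer.convert(float("inf"))
layer.post_quantization_cleanup()
```

With this shape, the `_is_quantized` check no longer blocks the successful quantization path, and the skip path leaves the layer in a consistent, callable state.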
```python
if not self._is_quantized:
    if hasattr(self, "input_observer"):
        if hasattr(self, "_layers") and self.input_observer in self._layers:
            self._layers.remove(self.input_observer)
    # Set call to pass-through/original
    if hasattr(self, "call"):
        pass
    self._const_variables = []
    self._tracker.lock()
    return
```
Same issue here: post_quantization_cleanup() now returns immediately whenever _is_quantized is False, but _is_quantized is only set to True at the end of this method (and convert() does not set it on success). This means the cleanup never switches self.call to call_fp8/call_int8 and will break both the static quantization workflow and prepare_deserialized_quantized_model() (which doesn't create input_observer, so leaving call() pointing at the observer path will raise at runtime). Instead of checking _is_quantized here, use a dedicated “skip quantization” flag set when calibration fails, and only take the early-return in that case after restoring the original call path.
```python
if not self._is_quantized:
    # Clean up observer only if it exists
    if hasattr(self, "input_observer"):
        if hasattr(self, "_layers") and self.input_observer in self._layers:
            self._layers.remove(self.input_observer)
    # Set call to pass-through/original
    if hasattr(self, "call"):
        # pass through
        pass
    self._const_variables = []
    self._tracker.lock()
    return
```
This change is meant to handle the “activation scale is inf (e.g., single-sample calibration) + const_scale/const_weight” scenario, but there is no regression test covering that failure mode. Since the repo already has JAX quantization pytest coverage (e.g., test/jax/test_save_load.py), please add a unit/integration test that calibrates with a 1-sample dataset that triggers inf scale and asserts the layer/model still runs correctly (no observer-dependent call path, no const var conversion when quantization is skipped).
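The requested regression test could look roughly like the sketch below. It is self-contained and uses stand-in helpers (`compute_scale`, `CalibratedLayer` are assumed names, not the repository's real classes); a real test in test/jax/ would instead build the actual quantized layer and calibrate it with a one-sample dataset.

```python
# Hedged sketch of the missing regression test: a single all-zero calibration
# sample yields a zero amax, so the scale formula divides by zero and gives
# inf; the layer must then skip quantization and still run correctly.
import math


def compute_scale(samples, qmax=127.0):
    # amax over a single all-zero sample is 0, so qmax / amax would be inf.
    amax = max(abs(v) for s in samples for v in s)
    return qmax / amax if amax > 0 else float("inf")


class CalibratedLayer:
    def __init__(self):
        self.quantized = False
        self.scale = None

    def calibrate(self, samples):
        self.scale = compute_scale(samples)
        self.quantized = math.isfinite(self.scale)

    def __call__(self, x):
        if not self.quantized:
            return x  # pass-through: quantization was skipped
        return [round(v * self.scale) / self.scale for v in x]


def test_inf_scale_skips_quantization():
    layer = CalibratedLayer()
    layer.calibrate([[0.0, 0.0, 0.0]])  # 1-sample calibration -> inf scale
    assert math.isinf(layer.scale)
    assert not layer.quantized
    # Output must be unchanged, not gibberish from converted const weights.
    assert layer([1.5, -2.0]) == [1.5, -2.0]


test_inf_scale_skips_quantization()
```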
anko-intel left a comment
The added code just disables the const_weight and const_scale feature.
```python
pass
self._const_variables = []
self._tracker.lock()
return
```
The early return just disables the conversion of weight and scale to const attributes at line 264 and later.
Type of Change
bug fix
Description
Many activation scales become inf when a single sample is used for calibration. Such a layer should be skipped once it is marked _not_quantized, but it still stores values due to the constant weight initialization.
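The failure mode described above can be illustrated with a few lines of arithmetic. This is an assumption about how the inf scale arises (a min/max-style observer seeing a zero range from one sample), not the exact formula used in the codebase.

```python
# With a single calibration sample, a min/max observer can observe a zero
# range; the scale formula then divides by zero and produces inf.
qmin, qmax = -128, 127
observed_min = observed_max = 0.0  # one sample that happens to be all zeros
value_range = observed_max - observed_min
scale = (qmax - qmin) / value_range if value_range else float("inf")
print(scale)  # inf
```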
Expected Behavior & Potential Risk
Resolves the gibberish-output issue caused by constant weights. No foreseeable risk.
How has this PR been tested?
It has been tested using the example_old.py file.
Dependency Change?
No dependency change.