Skip to content

Commit eabe49e

Browse files
authored
Add tolerance to QuantizedInputWrapper
Differential Revision: D95822313
Pull Request resolved: pytorch#18019
1 parent 6abcb53 commit eabe49e

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,29 +258,42 @@ class QuantizedInputWrapper(torch.nn.Module):
258258
If provided, extracts quant params from graph.
259259
quant_args: Optional dict mapping input index to (scale, zero_point, qmin, qmax, dtype).
260260
If provided, uses these directly instead of extracting from graph.
261+
expected_inputs: Optional dict mapping input index to the expected
262+
dequantized tensor. After dequantization, the result is compared
263+
against these values using atol/rtol. Raises ValueError if exceeded.
264+
atol: Absolute tolerance for the expected-value check (default 1e-4).
265+
rtol: Relative tolerance for the expected-value check (default 1e-4).
261266
262267
Example:
263268
# Extract from graph
264269
wrapper = QuantizedInputWrapper(quantized_module, input_names=["x"])
265270
266-
# Explicit quant args
271+
# Explicit quant args with expected-value validation
267272
wrapper = QuantizedInputWrapper(
268273
quantized_module,
269274
quant_args={0: (1/255, 0, 0, 255, torch.uint8)},
275+
expected_inputs={0: reference_float_tensor},
276+
atol=1e-3,
270277
)
271278
"""
272279

273280
def __init__(
274281
self,
275282
module: GraphModule,
276283
input_args: Optional[Union[list[str], dict[int, QuantArgs]]] = None,
284+
expected_inputs: Optional[dict[int, torch.Tensor]] = None,
285+
atol: float = 1e-4,
286+
rtol: float = 1e-4,
277287
) -> None:
278288
super().__init__()
279289
self.module: GraphModule = module
280290
self.quant_args: dict[int, QuantArgs] = {}
281291
self.expected_shapes: dict[int, tuple[int, ...]] = (
282292
extract_input_shapes_from_graph(module)
283293
)
294+
self.expected_inputs: Optional[dict[int, torch.Tensor]] = expected_inputs
295+
self.atol: float = atol
296+
self.rtol: float = rtol
284297

285298
if input_args is not None:
286299
logger.warning(
@@ -317,6 +330,23 @@ def forward(self, *args: torch.Tensor) -> Any:
317330
)
318331
dequantized_args.append(node)
319332

333+
# Check dequantized values against expected inputs
334+
expected_inputs = self.expected_inputs
335+
if expected_inputs is not None:
336+
for index, expected in expected_inputs.items():
337+
if index >= len(dequantized_args):
338+
continue
339+
actual = dequantized_args[index]
340+
if not torch.allclose(actual, expected, atol=self.atol, rtol=self.rtol):
341+
max_abs_diff = (actual - expected).abs().max().item()
342+
mean_abs_diff = (actual - expected).abs().mean().item()
343+
msg = (
344+
f"Dequantized input at index {index} differs from expected value: "
345+
f"max_abs_diff={max_abs_diff:.6g}, mean_abs_diff={mean_abs_diff:.6g} "
346+
f"(atol={self.atol}, rtol={self.rtol})"
347+
)
348+
raise ValueError(msg)
349+
320350
return self.module(*dequantized_args)
321351

322352

0 commit comments

Comments
 (0)