@@ -258,29 +258,42 @@ class QuantizedInputWrapper(torch.nn.Module):
258258 If provided, extracts quant params from graph.
259259 quant_args: Optional dict mapping input index to (scale, zero_point, qmin, qmax, dtype).
260260 If provided, uses these directly instead of extracting from graph.
261+ expected_inputs: Optional dict mapping input index to the expected
262+ dequantized tensor. After dequantization, the result is compared
263+ against these values using atol/rtol. Raises ValueError if exceeded.
264+ atol: Absolute tolerance for the expected-value check (default 1e-4).
265+ rtol: Relative tolerance for the expected-value check (default 1e-4).
261266
262267 Example:
263268 # Extract from graph
264269 wrapper = QuantizedInputWrapper(quantized_module, input_args=["x"])
265270
266- # Explicit quant args
271+ # Explicit quant args with expected-value validation
267272 wrapper = QuantizedInputWrapper(
268273 quantized_module,
269274 input_args={0: (1/255, 0, 0, 255, torch.uint8)},
275+ expected_inputs={0: reference_float_tensor},
276+ atol=1e-3,
270277 )
271278 """
272279
273280 def __init__ (
274281 self ,
275282 module : GraphModule ,
276283 input_args : Optional [Union [list [str ], dict [int , QuantArgs ]]] = None ,
284+ expected_inputs : Optional [dict [int , torch .Tensor ]] = None ,
285+ atol : float = 1e-4 ,
286+ rtol : float = 1e-4 ,
277287 ) -> None :
278288 super ().__init__ ()
279289 self .module : GraphModule = module
280290 self .quant_args : dict [int , QuantArgs ] = {}
281291 self .expected_shapes : dict [int , tuple [int , ...]] = (
282292 extract_input_shapes_from_graph (module )
283293 )
294+ self .expected_inputs : Optional [dict [int , torch .Tensor ]] = expected_inputs
295+ self .atol : float = atol
296+ self .rtol : float = rtol
284297
285298 if input_args is not None :
286299 logger .warning (
@@ -317,6 +330,23 @@ def forward(self, *args: torch.Tensor) -> Any:
317330 )
318331 dequantized_args .append (node )
319332
333+ # Check dequantized values against expected inputs
334+ expected_inputs = self .expected_inputs
335+ if expected_inputs is not None :
336+ for index , expected in expected_inputs .items ():
337+ if index >= len (dequantized_args ):
338+ continue
339+ actual = dequantized_args [index ]
340+ if not torch .allclose (actual , expected , atol = self .atol , rtol = self .rtol ):
341+ max_abs_diff = (actual - expected ).abs ().max ().item ()
342+ mean_abs_diff = (actual - expected ).abs ().mean ().item ()
343+ msg = (
344+ f"Dequantized input at index { index } differs from expected value: "
345+ f"max_abs_diff={ max_abs_diff :.6g} , mean_abs_diff={ mean_abs_diff :.6g} "
346+ f"(atol={ self .atol } , rtol={ self .rtol } )"
347+ )
348+ raise ValueError (msg )
349+
320350 return self .module (* dequantized_args )
321351
322352
0 commit comments