Fix torch var/std converter ignoring correction > 1 under tracing#2718
Open
adityasingh2400 wants to merge 1 commit into
Open
Fix torch var/std converter ignoring correction > 1 under tracing#2718adityasingh2400 wants to merge 1 commit into
adityasingh2400 wants to merge 1 commit into
Conversation
Under torch.jit.trace, torch.var/std with the correction keyword lower to the same aten::var/aten::std signature as the unbiased keyword: the value lands in the same positional slot. The converter always read that slot as the bool unbiased, so any correction >= 2 was silently treated as unbiased=True (i.e. correction=1), producing a wrong variance/standard deviation. correction 0 and 1 happened to be correct because they coincide with unbiased False/True. Read the argument as correction instead, which subsumes unbiased (unbiased=False == correction 0, unbiased=True == correction 1), and route it through _var(correction=...). Extend test_var_std_with_correction to cover correction=2.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Bug
When a model is converted via
torch.jit.trace,torch.var/torch.stdcalled with thecorrectionkeyword lower to the sameaten::var/aten::stdsignature as the olderunbiasedkeyword. Both keywords end up in the same positional slot (aten::var(self, dim, <slot>, keepdim)), and the slot's constant is just an int.The converter always read that slot as the boolean
unbiased. So anycorrection >= 2was truthy and silently treated asunbiased=True, i.e.correction=1.correction=0andcorrection=1happened to come out right only because they coincide withunbiased=False/True.Repro (torch 2.12, coremltools from main):
correction=3, 4, ...are all collapsed to thecorrection=1result the same way.stdhas the same problem since it reuses this logic.Fix
Read the argument as
correctionrather thanunbiased.correctionis a strict generalization:unbiased=Falseiscorrection=0andunbiased=Trueiscorrection=1, so existingunbiased-traced graphs keep their current behavior, andcorrection >= 2now divides byN - correctionas PyTorch does. The value is routed through the existing_var(correction=...)path. Theunbiased=kwarg (export) is still accepted and mapped ontocorrection.Test
test_var_std_with_correctiononly parametrizedcorrectionover[0, 1], which is exactly the range the old code got right by accident, so the bug was invisible. Extended it to includecorrection=2. With this fix the full matrix (var/std x correction[0,1,2] x dim[[0,2],[1],[2]] x keepdim) passes; on main thecorrection=2cases fail with a numerical mismatch.