Skip to content

Fix torch var/std converter ignoring correction > 1 under tracing#2718

Open
adityasingh2400 wants to merge 1 commit into
apple:mainfrom
adityasingh2400:fix-var-std-correction
Open

Fix torch var/std converter ignoring correction > 1 under tracing#2718
adityasingh2400 wants to merge 1 commit into
apple:mainfrom
adityasingh2400:fix-var-std-correction

Conversation

@adityasingh2400
Copy link
Copy Markdown

Bug

When a model is converted via torch.jit.trace, torch.var/torch.std called with the correction keyword lower to the same aten::var/aten::std signature as the older unbiased keyword. Both keywords end up in the same positional slot (aten::var(self, dim, <slot>, keepdim)), and the slot's constant is just an int.

The converter always read that slot as the boolean unbiased. So any correction >= 2 was truthy and silently treated as unbiased=True, i.e. correction=1. correction=0 and correction=1 happened to come out right only because they coincide with unbiased=False/True.

Repro (torch 2.12, coremltools from main):

import torch, coremltools as ct, numpy as np
torch.manual_seed(0)
x = torch.randn(3, 5)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.var(x, dim=-1, correction=2, keepdim=True)

ts = torch.jit.trace(M().eval(), x)
m = ct.convert(ts, inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float32)],
               compute_units=ct.ComputeUnit.CPU_ONLY, minimum_deployment_target=ct.target.iOS17)

print("torch ", M()(x).flatten()[:2].tolist())
print("coreml", np.asarray(list(m.predict({"x": x.numpy()}).values())[0]).flatten()[:2].tolist())
torch  [2.762..., 1.057...]   # divides by N - 2 = 3
coreml [2.072..., 0.793...]   # divides by N - 1 = 4  (wrong, == correction=1)

correction=3, 4, ... are all collapsed to the correction=1 result the same way. std has the same problem since it reuses this logic.

Fix

Read the argument as correction rather than unbiased. correction is a strict generalization: unbiased=False is correction=0 and unbiased=True is correction=1, so existing unbiased-traced graphs keep their current behavior, and correction >= 2 now divides by N - correction as PyTorch does. The value is routed through the existing _var(correction=...) path. The unbiased= kwarg (export) is still accepted and mapped onto correction.

Test

test_var_std_with_correction only parametrized correction over [0, 1], which is exactly the range the old code got right by accident, so the bug was invisible. Extended it to include correction=2. With this fix the full matrix (var/std x correction[0,1,2] x dim[[0,2],[1],[2]] x keepdim) passes; on main the correction=2 cases fail with a numerical mismatch.

Under torch.jit.trace, torch.var/std with the correction keyword lower to the
same aten::var/aten::std signature as the unbiased keyword: the value lands in
the same positional slot. The converter always read that slot as the bool
unbiased, so any correction >= 2 was silently treated as unbiased=True (i.e.
correction=1), producing a wrong variance/standard deviation. correction 0 and
1 happened to be correct because they coincide with unbiased False/True.

Read the argument as correction instead, which subsumes unbiased
(unbiased=False == correction 0, unbiased=True == correction 1), and route it
through _var(correction=...). Extend test_var_std_with_correction to cover
correction=2.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant