Skip to content

Cast torch.multinomial output to int to match torch's int64 contract (fixes #2337)#2710

Open
LeSingh1 wants to merge 1 commit into
apple:mainfrom
LeSingh1:fix/multinomial-int-dtype
Open

Cast torch.multinomial output to int to match torch's int64 contract (fixes #2337)#2710
LeSingh1 wants to merge 1 commit into
apple:mainfrom
LeSingh1:fix/multinomial-int-dtype

Conversation

@LeSingh1
Copy link
Copy Markdown
Contributor

Summary

torch.multinomial(input, num_samples, ...) is documented to return a LongTensor (int64) of sampled indices. The current converter dispatches to MIL's random_categorical, which preserves the float dtype of its probs input. So a model that ends in torch.multinomial produces an fp16/fp32 output from the converted Core ML model, even though the user wrote code that expects integers.

Repro on current main:

class M(nn.Module):
    def forward(self, x):
        return torch.multinomial(x, num_samples=3, replacement=True)

x = torch.tensor([[0.1, 0.3, 0.5, 0.1], [0.5, 0.2, 0.2, 0.1]])
print(M().eval()(x).dtype)                       # torch.int64

prog = ct.convert(torch.jit.trace(M().eval(), x),
                  inputs=[ct.TensorType(name="x", shape=x.shape)],
                  convert_to="milinternal")
print(prog.functions["main"].outputs[0].dtype)   # double / fp32 — wrong

That mismatch silently breaks anyone downstream of multinomial (e.g. gather/index_select over a vocab table after sampling). Tracked in #2337.

Fix

Add an explicit mb.cast(..., dtype="int32") after random_categorical so the converted output dtype matches torch:

samples = mb.random_categorical(x=x, size=num_samples, mode="probs")
samples = mb.cast(x=samples, dtype="int32", name=node.name)
context.add(samples)

int32 is the integer dtype Core ML programs use for indices throughout the converter (matching the existing randint, argmax, etc. paths).

Tests

Adds TestMultinomial::test_multinomial_returns_int:

  • Converts a torch.multinomial(x, 3, replacement=True) model with convert_to="milinternal" (runs without BlobWriter, as the structural assertion is the goal).
  • Asserts the program contains both random_categorical and cast.
  • Asserts the final output dtype is an integer type via types.is_int(out.dtype).

Passes locally:

PASSED test_multinomial_returns_int

The existing TestMultinomial cases (test_multinomial, test_multinomial_probs_instead_of_logits, test_multinomial_not_supported) are unaffected by the cast — they either don't assert on dtype, or compare counts where the int cast is a no-op semantically.

Issue

Fixes #2337

`torch.multinomial(input, num_samples, ...)` is documented to return a
LongTensor (int64) of sampled indices. The converter dispatched to MIL's
`random_categorical`, which preserves the float dtype of its `probs`
input. That meant any downstream op that expected integer indices
(e.g. `index_select`, `gather`, or simple Python-side `argmax`-style
math on the result) silently received fp16/fp32 — and in some cases the
conversion of the consumer would fail later for non-obvious reasons.

Add an explicit `mb.cast(..., dtype="int32")` after the
`random_categorical` call so the output dtype matches torch.

Adds `TestMultinomial::test_multinomial_returns_int` which builds the
program with `convert_to="milinternal"` and asserts the output dtype
is int. Runs without BlobWriter; existing TestMultinomial tests still
pass.

Fixes apple#2337
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

random_categorical returns float when it should return int

1 participant