Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9146,8 +9146,12 @@ def multinomial(context, node):
if num_samples > 1 and not replacement:
raise ValueError("When num_samples is larger than 1, only replacement=True is supported.")
# Based on PyTorch documentations, the input to `torch.multinomial` is probability, not logit.
x = mb.random_categorical(x=x, size=num_samples, mode="probs", name=node.name)
context.add(x)
samples = mb.random_categorical(x=x, size=num_samples, mode="probs")
# `torch.multinomial` returns int64-valued indices; `mb.random_categorical` keeps the
# input float dtype. Without this cast the converted model emits floats where downstream
# ops (e.g. `gather`/`index_select`) expect integer indices, matching #2337.
samples = mb.cast(x=samples, dtype="int32", name=node.name)
context.add(samples)


@register_torch_op
Expand Down
36 changes: 36 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14615,6 +14615,42 @@ def forward(self, x):
# The counting of 1 in PyTorch and CoreML output should be similar.
assert np.abs(np.sum(mlmodel_out) - np.sum(torch_out)) / mlmodel_out.size < 0.05

def test_multinomial_returns_int(self):
"""Regression test for https://github.com/apple/coremltools/issues/2337.
`torch.multinomial` returns an int64 tensor of indices; the converter
used to emit the raw float output of MIL `random_categorical`, which
broke downstream `gather` / `index_select` consumers. The converted
program must end with an int cast so the output dtype matches torch.
"""

class TestModel(nn.Module):
def forward(self, x):
return torch.multinomial(x, num_samples=3, replacement=True)

model = TestModel().eval()
input_data = torch.tensor([[0.1, 0.3, 0.5, 0.1], [0.5, 0.2, 0.2, 0.1]])
torch_model = torch.jit.trace(model, input_data)

prog = ct.convert(
torch_model,
inputs=[ct.TensorType(name="x", shape=input_data.shape)],
convert_to="milinternal",
minimum_deployment_target=ct.target.iOS17,
)

ops = get_op_types_in_program(prog)
assert "random_categorical" in ops
assert "cast" in ops, (
"Expected an int cast after `random_categorical` so the output "
"dtype matches torch.multinomial's int64 contract (issue #2337)."
)

# The final output must be an integer dtype.
out = prog.functions["main"].outputs[0]
assert types.is_int(out.dtype), (
f"multinomial output dtype must be int, got {out.dtype}"
)

@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(compute_units, backends),
Expand Down