
Add torch op coverage for LLM attention mask construction#2668

Merged
TobyRoseman merged 1 commit into apple:main from john-rocky:llm-attention-mask-op-coverage
Apr 13, 2026

Conversation

Contributor

@john-rocky john-rocky commented Apr 10, 2026

Summary

Adds the small set of torch ops that HuggingFace attention-mask construction emits via torch.export but coremltools didn't yet handle. Discovered while converting google/gemma-4-E2B-it, but the same gaps affect any decoder LLM that builds masks the way HuggingFace's eager attention path does.

All five gaps are independent fixes for existing asymmetries in the registry (e.g. bitwise_and exists but bitwise_or did not; new_zeros exists but new_ones did not). None of them change behavior of any op that already worked.

Reproduction (before this PR)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import coremltools as ct

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E2B-it",
    torch_dtype=torch.float32,
    attn_implementation="eager",
).eval()
tok = AutoTokenizer.from_pretrained("google/gemma-4-E2B-it")
enc = tok("hello", return_tensors="pt", padding="max_length", max_length=32)

class Wrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m
    def forward(self, ids, attn):
        return self.m(input_ids=ids, attention_mask=attn, use_cache=False).logits

with torch.no_grad():
    ep = torch.export.export(
        Wrapper(model),
        args=(enc["input_ids"], enc["attention_mask"]),
        strict=False,
    ).run_decompositions({})

ct.convert(ep, convert_to="mlprogram", minimum_deployment_target=ct.target.iOS18)
# NotImplementedError: Unsupported fx node or_1, kind __or__

After fixing __or__, the converter then fails on __and__, then new_ones, then where.scalarother, then bitwise_and with mixed bool+float inputs. This PR addresses all five gaps in one place.

Changes

  1. sanitize_op_kind — strip the __name__ wrapper after the namespace and overload suffix have been removed. Previously aten::__or__.Tensor sanitized to __or__ instead of or, so the registry lookup missed even when an or handler existed. The original logic only stripped __...__ if the whole op kind matched, which is only the case for the legacy TorchScript form.

  2. Register bitwise_or / bitwise_xor with or / xor aliases for the post-sanitize form. Both reuse the logical_* MIL primitives, mirroring the existing bitwise_and registration. The asymmetry was clearly accidental.

  3. Relax bitwise_and / bitwise_or / bitwise_xor to accept mixed-bool inputs. Previously they required all inputs to be bool; now if at least one input is bool we cast both to bool and lower to the corresponding logical_* op. Pure non-bool inputs are still rejected with the same error so genuine integer bitwise math is unchanged. This unblocks Gemma-style mask combination where a bool causal mask meets a float padding mask.

  4. Register aten::new_ones mirroring the existing new_zeros, using _make_fill_op so float-typed shape inputs from torch.export are coerced to int32. (The same float-shape issue is latent in new_zeros; new_ones avoids it from the start by using the safer helper.)

  5. Add where.ScalarOther as an alias on the existing where handler, which already does dtype promotion and broadcasting. The where.self and where.scalarself overloads were already covered; where.scalarother was the missing third.
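The dunder-stripping order in item 1 can be sketched as a standalone function. This is an illustrative approximation of the lookup-key normalization, not the actual coremltools implementation:

```python
import re

def sanitize_op_kind(op_kind: str) -> str:
    """Sketch of the registry-key normalization described above
    (standalone approximation, not the real coremltools code)."""
    op_kind = op_kind.lower()
    # Strip the namespace prefix, e.g. "aten::" or "prim::".
    if "::" in op_kind:
        op_kind = op_kind.rsplit("::", 1)[-1]
    # Strip the overload suffix, e.g. ".Tensor" or ".ScalarOther".
    op_kind = op_kind.split(".", 1)[0]
    # Strip the dunder wrapper AFTER the two steps above, so that
    # "aten::__or__.Tensor" -> "__or__" -> "or". The old logic only
    # handled the bare legacy TorchScript form like "__or__".
    match = re.fullmatch(r"__(\w+)__", op_kind)
    if match:
        op_kind = match.group(1)
    return op_kind
```

With this ordering, both the legacy form (`__or__`) and the torch.export form (`aten::__or__.Tensor`) normalize to the same key, so a single registered `or` handler covers both.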

Test plan

  • New unit tests in test_internal_graph.py covering the dunder-after-namespace case for sanitize_op_kind:
    • aten::__or__.Tensor -> or
    • aten::__and__.Tensor -> and
    • aten::__xor__.Tensor -> xor
    • Plus the existing legacy forms (__add__ -> add, aten::add.Tensor -> add, etc.) to guard against regressions
  • New op-level tests in test_torch_ops.py:
    • TestNewOnes::test_new_ones_static
    • TestBitwiseOr::test_bitwise_or and test_or_operator (the latter exercises the tensor | tensor form)
    • TestBitwiseXor::test_bitwise_xor
  • Existing TestBitwiseAnd and all TestLogical* tests still pass (the relaxed bool check is a strict superset of the previous behavior)

End-to-end validation on google/gemma-4-E2B-it (compute_precision=FLOAT32):

ref last argmax: 7001 -> ' France'
cml last argmax: 7001 -> ' France'
top-5 overlap: 5/5
per-position argmax agreement: 100.0% (5/5)
max abs diff: 5.30e-02
VERDICT: Core ML model is functionally correct.
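The metrics above can be computed with a helper along these lines. This is a hypothetical sketch operating on plain per-position logit lists, not the actual validation script:

```python
def compare_logits(ref, cml, k=5):
    """Sketch of the validation metrics reported above: per-position
    argmax agreement, top-k overlap at the last position, and the
    maximum absolute logit difference (illustrative helper)."""
    argmax = lambda v: max(range(len(v)), key=v.__getitem__)

    def top_k(v):
        # Indices of the k largest logits.
        return set(sorted(range(len(v)), key=v.__getitem__, reverse=True)[:k])

    agreement = sum(argmax(r) == argmax(c) for r, c in zip(ref, cml)) / len(ref)
    overlap = len(top_k(ref[-1]) & top_k(cml[-1]))
    max_abs_diff = max(abs(x - y)
                       for r, c in zip(ref, cml)
                       for x, y in zip(r, c))
    return {"argmax_agreement": agreement,
            "top_k_overlap": overlap,
            "max_abs_diff": max_abs_diff}
```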

(fp16 correctness for the same model is addressed by a separate PR: #2669 — no functional dependency between the two PRs, but both are needed to convert Gemma 4 cleanly.)

Adds the small set of torch ops that HuggingFace attention-mask code
emits via torch.export but coremltools didn't yet handle, exposed
while converting google/gemma-4-E2B-it:

- Register bitwise_or / bitwise_xor (with `or` / `xor` aliases for the
  post-sanitize form). The existing bitwise_and was the only registered
  member of the family; this restores symmetry. Both new handlers reuse
  the logical_* MIL primitives, matching the existing bitwise_and pattern.
- Relax bitwise_and / bitwise_or / bitwise_xor to accept mixed-bool
  inputs (cast both to bool when at least one is bool). Pure non-bool
  inputs are still rejected with the same error so genuine integer
  bitwise math is unchanged. This unblocks Gemma-style mask combination
  where a bool causal mask meets a float padding mask.
- Register aten::new_ones mirroring the existing new_zeros, using
  _make_fill_op so float-typed shape inputs from torch.export are
  coerced to int32.
- Add where.ScalarOther as an alias on the existing where handler
  (which already does dtype promotion and broadcasting).
- Fix sanitize_op_kind so the `__name__` wrapper is also stripped after
  the namespace and overload suffix have been removed. Previously
  aten::__or__.Tensor sanitized to "__or__" instead of "or", making the
  registry lookup miss even when an "or" handler existed.
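The mixed-bool relaxation described above can be sketched in plain Python. This is an illustrative stand-in for the MIL `logical_or` lowering, not the real op handler:

```python
def bitwise_or_relaxed(xs, ys):
    """Sketch of the relaxed bitwise_or lowering described above:
    if at least one input is boolean, cast both sides to bool and
    take the logical OR; pure non-bool inputs are still rejected,
    as in the original code (illustrative stand-in, not the real
    MIL lowering)."""
    xs_bool = all(isinstance(x, bool) for x in xs)
    ys_bool = all(isinstance(y, bool) for y in ys)
    if not (xs_bool or ys_bool):
        raise NotImplementedError(
            "bitwise_or on non-bool inputs is not supported")
    # Cast-to-bool matches torch semantics: nonzero floats -> True.
    return [bool(x) or bool(y) for x, y in zip(xs, ys)]
```

This is the shape of the Gemma-style case: a bool causal mask on one side, a float padding mask on the other.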

Tests:
- Unit tests for sanitize_op_kind covering the dunder-after-namespace
  case in test_internal_graph.py.
- Op-level tests for new_ones, bitwise_or, bitwise_xor and the
  `tensor | tensor` operator form in test_torch_ops.py.

Validated end-to-end on google/gemma-4-E2B-it: torch.export ->
ct.convert -> mlprogram now succeeds and the fp32 model output
matches the PyTorch reference (top-5 5/5, per-position argmax 100%,
max abs diff 0.05).
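The int32 shape coercion attributed to _make_fill_op above amounts to something like the following. This is an illustrative sketch, not the actual helper:

```python
def coerce_fill_shape(shape):
    """Sketch of the shape coercion described above: torch.export can
    trace size values as floats, but a fill op needs integer extents,
    so coerce each dimension to a plain int and reject non-integral
    values (hypothetical stand-in for the _make_fill_op handling)."""
    coerced = []
    for dim in shape:
        as_int = int(dim)
        if float(dim) != as_int:
            raise ValueError(f"non-integer dimension: {dim!r}")
        coerced.append(as_int)
    return coerced
```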
@john-rocky john-rocky marked this pull request as ready for review April 10, 2026 02:26

@TobyRoseman TobyRoseman merged commit a74c68c into apple:main Apr 13, 2026
@TobyRoseman
Collaborator

Thanks for your contribution @john-rocky.

@john-rocky john-rocky deleted the llm-attention-mask-op-coverage branch April 13, 2026 21:49
