Add torch op coverage for LLM attention mask construction #2668
Merged
TobyRoseman merged 1 commit into apple:main on Apr 13, 2026
Conversation
Adds the small set of torch ops that HuggingFace attention-mask code emits via `torch.export` but coremltools didn't yet handle, exposed while converting google/gemma-4-E2B-it:

- Register `bitwise_or` / `bitwise_xor` (with `or` / `xor` aliases for the post-sanitize form). The existing `bitwise_and` was the only registered member of the family; this restores symmetry. Both new handlers reuse the `logical_*` MIL primitives, matching the existing `bitwise_and` pattern.
- Relax `bitwise_and` / `bitwise_or` / `bitwise_xor` to accept mixed-bool inputs (cast both to bool when at least one is bool). Pure non-bool inputs are still rejected with the same error, so genuine integer bitwise math is unchanged. This unblocks Gemma-style mask combination where a bool causal mask meets a float padding mask.
- Register `aten::new_ones` mirroring the existing `new_zeros`, using `_make_fill_op` so float-typed shape inputs from `torch.export` are coerced to int32.
- Add `where.ScalarOther` as an alias on the existing `where` handler (which already does dtype promotion and broadcasting).
- Fix `sanitize_op_kind` so the `__name__` wrapper is also stripped after the namespace and overload suffix have been removed. Previously `aten::__or__.Tensor` sanitized to `__or__` instead of `or`, making the registry lookup miss even when an `or` handler existed.

Tests:

- Unit tests for `sanitize_op_kind` covering the dunder-after-namespace case in `test_internal_graph.py`.
- Op-level tests for `new_ones`, `bitwise_or`, `bitwise_xor` and the `tensor | tensor` operator form in `test_torch_ops.py`.

Validated end-to-end on google/gemma-4-E2B-it: `torch.export` -> `ct.convert` -> mlprogram now succeeds and the fp32 model output matches the PyTorch reference (top-5 5/5, per-position argmax 100%, max abs diff 0.05).
TobyRoseman (Collaborator) approved these changes on Apr 13, 2026.
Thanks for your contribution @john-rocky.
Summary
Adds the small set of torch ops that HuggingFace attention-mask construction emits via `torch.export` but coremltools didn't yet handle. Discovered while converting google/gemma-4-E2B-it, but the same gaps affect any decoder LLM that builds masks the way HuggingFace's `eager` attention path does. All five gaps are independent fixes for existing asymmetries in the registry (e.g. `bitwise_and` exists but `bitwise_or` did not; `new_zeros` exists but `new_ones` did not). None of them change the behavior of any op that already worked.

Reproduction (before this PR)
After fixing `__or__`, the converter then fails on `__and__`, then `new_ones`, then `where.scalarother`, then `bitwise_and` with mixed bool+float inputs. This PR addresses all of them in one place.

Changes
`sanitize_op_kind`: strip the `__name__` wrapper after the namespace and overload suffix have been removed. Previously `aten::__or__.Tensor` sanitized to `__or__` instead of `or`, so the registry lookup missed even when an `or` handler existed. The original logic only stripped `__...__` if the whole op kind matched, which is only the case for the legacy TorchScript form.
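The fixed sanitization order can be sketched as follows; this is a minimal illustrative version, not the actual coremltools implementation:

```python
import re


def sanitize_op_kind(op_kind: str) -> str:
    """Illustrative: namespace, then overload suffix, then dunder wrapper."""
    op_kind = op_kind.lower()
    # 1. Drop the namespace prefix, e.g. "aten::".
    if "::" in op_kind:
        op_kind = op_kind.split("::")[-1]
    # 2. Drop the overload suffix, e.g. ".Tensor" or ".ScalarOther".
    op_kind = op_kind.split(".")[0]
    # 3. The fix: strip the dunder wrapper *after* steps 1-2, so
    #    "aten::__or__.Tensor" reaches here as "__or__" and becomes "or".
    match = re.fullmatch(r"__(\w+)__", op_kind)
    if match:
        op_kind = match.group(1)
    return op_kind


print(sanitize_op_kind("aten::__or__.Tensor"))  # or
print(sanitize_op_kind("__add__"))              # add
print(sanitize_op_kind("aten::add.Tensor"))     # add
```

With the old ordering, step 3 ran only on the raw input, so it matched `__add__` (the legacy TorchScript form) but never the `aten::`-namespaced `torch.export` form.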
Register `bitwise_or` / `bitwise_xor` with `or` / `xor` aliases for the post-sanitize form. Both reuse the `logical_*` MIL primitives, mirroring the existing `bitwise_and` registration. The asymmetry was clearly accidental.
Relax `bitwise_and` / `bitwise_or` / `bitwise_xor` to accept mixed-bool inputs. Previously they required all inputs to be bool; now, if at least one input is bool, we cast both to bool and lower to the corresponding `logical_*` op. Pure non-bool inputs are still rejected with the same error, so genuine integer bitwise math is unchanged. This unblocks Gemma-style mask combination where a bool causal mask meets a float padding mask.
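The relaxed check amounts to a small promotion rule. The function below is a hypothetical sketch of that rule in isolation; the real handlers operate on MIL vars rather than dtype strings:

```python
def promote_bitwise_inputs(lhs_dtype: str, rhs_dtype: str) -> str:
    """If at least one side is bool, compute in bool; otherwise reject,
    so genuine integer bitwise math keeps its original error."""
    if "bool" in (lhs_dtype, rhs_dtype):
        return "bool"  # cast both sides to bool, lower to logical_*
    raise NotImplementedError(
        "bitwise ops on non-bool inputs are not supported"
    )


print(promote_bitwise_inputs("bool", "fp32"))  # bool  (Gemma mask case)
print(promote_bitwise_inputs("bool", "bool"))  # bool  (previous behavior)
# promote_bitwise_inputs("int32", "int32") -> NotImplementedError
```

Because the all-bool case still maps to bool and the non-bool case still raises, the relaxation is a strict superset of the previous behavior.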
Register `aten::new_ones` mirroring the existing `new_zeros`, using `_make_fill_op` so float-typed shape inputs from `torch.export` are coerced to int32 (a latent issue in `new_zeros` as well, which this PR avoids by using the safer helper).
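The shape-coercion behavior can be illustrated in isolation. `make_fill` below is a hypothetical pure-Python stand-in for `_make_fill_op` (which in reality builds a MIL `fill` op), kept only to show why the float-to-int coercion matters:

```python
def make_fill(shape, value):
    """Hypothetical stand-in for _make_fill_op: torch.export sometimes
    traces shapes as floats, so coerce each dim to int before filling."""
    dims = [int(d) for d in shape]  # e.g. [2.0, 3.0] -> [2, 3]
    if len(dims) == 1:
        return [value] * dims[0]
    return [make_fill(dims[1:], value) for _ in range(dims[0])]


# new_ones with a float-typed shape, as emitted by torch.export:
new_ones = make_fill([2.0, 3.0], 1.0)
print(new_ones)  # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```

Without the coercion, a float-typed shape would fail when used as a dimension count, which is the latent `new_zeros` issue the helper sidesteps.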
Add `where.ScalarOther` as an alias on the existing `where` handler, which already does dtype promotion and broadcasting. The `where.self` and `where.scalarself` overloads were already covered; `where.scalarother` was the missing third.
Test plan

- Unit tests in `test_internal_graph.py` covering the dunder-after-namespace case for `sanitize_op_kind`:
  - `aten::__or__.Tensor` → `or`
  - `aten::__and__.Tensor` → `and`
  - `aten::__xor__.Tensor` → `xor`
  - existing forms (`__add__` → `add`, `aten::add.Tensor` → `add`, etc.) to guard against regressions
- Op-level tests in `test_torch_ops.py`:
  - `TestNewOnes::test_new_ones_static`
  - `TestBitwiseOr::test_bitwise_or` and `test_or_operator` (the latter exercises the `tensor | tensor` form)
  - `TestBitwiseXor::test_bitwise_xor`
  - `TestBitwiseAnd` and all `TestLogical*` tests still pass (the relaxed bool check is a strict superset of the previous behavior)

End-to-end validation on
`google/gemma-4-E2B-it` (`compute_precision=FLOAT32`): `torch.export` -> `ct.convert` -> mlprogram now succeeds and the fp32 model output matches the PyTorch reference (top-5 5/5, per-position argmax 100%, max abs diff 0.05).

(`fp16` correctness for the same model is addressed by a separate PR, #2669; there is no functional dependency between the two PRs, but both are needed to convert Gemma 4 cleanly.)