-
Notifications
You must be signed in to change notification settings - Fork 795
NXP backend: added aten.split support #16276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NXP backend: added aten.split support #16276
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16276
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "release notes: nxp" |
|
@pytorchbot label "module: nxp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds support for the aten.split operator in the NXP backend by introducing a new pass that decomposes split operations into slice operations. This enables the backend to handle tensor splitting operations that were previously unsupported.
- Implements
DecomposeSplitToSlicesPassto convert split operations into equivalent slice operations - Adds comprehensive test coverage for different split scenarios including split with size, split with sections, GRU-based splits, and single-chunk edge cases
- Integrates the new pass into the default NXP backend pass pipeline
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| backends/nxp/aten_passes/decompose_split_to_slices_pass.py | Implements the core pass logic to decompose split operations into slice operations |
| backends/nxp/aten_passes/neutron_aten_pass_manager.py | Integrates the new decompose split pass into the default pass pipeline |
| backends/nxp/tests/models.py | Adds test model classes (GRUModel, SplitWithSize, SplitWithSections) to support split operation testing |
| backends/nxp/tests/test_decompose_split_to_slices.py | Provides comprehensive test coverage for the split decomposition functionality |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
b620f90 to
f49aa4a
Compare
|
Fixed the issues found by @MartinPavella, please review them. Thank you! |
|
Issues fixed. Thanks for the review and please re-review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| pytest.param((8,), 3, 0, id="1D."), | ||
| pytest.param((4, 8), 5, 1, id="2D."), |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test ID has a trailing period which is inconsistent with other test IDs in the same parametrization. The ID should be "2D" instead of "2D." for consistency.
| pytest.param((8,), 3, 0, id="1D."), | |
| pytest.param((4, 8), 5, 1, id="2D."), | |
| pytest.param((8,), 3, 0, id="1D"), | |
| pytest.param((4, 8), 5, 1, id="2D"), |
| @pytest.mark.parametrize( | ||
| "input_shape, split_size, dim", | ||
| [ | ||
| pytest.param((8,), 3, 0, id="1D."), |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test ID has a trailing period which is inconsistent with other test IDs in the same parametrization. The ID should be "1D" instead of "1D." for consistency.
| @pytest.mark.parametrize( | ||
| "input_shape, sections, dim", | ||
| [ | ||
| pytest.param((8,), [5, 3], 0, id="1D."), |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test ID has a trailing period which is inconsistent with other test IDs in the parametrization. The ID should be "1D" instead of "1D." for consistency.
| "input_shape, sections, dim", | ||
| [ | ||
| pytest.param((8,), [5, 3], 0, id="1D."), | ||
| pytest.param((4, 8), [3, 3, 2], 1, id="2D."), |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test ID has a trailing period which is inconsistent with other test IDs in the parametrization. The ID should be "2D" instead of "2D." for consistency.
| split_nodes_chunks = list(split_nodes_chunks) | ||
|
|
||
| if not isinstance(split_nodes_chunks, list): | ||
| raise RuntimeError("Faulty split chunks") |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message "Faulty split chunks" is unclear and unhelpful for debugging. Consider providing more context about what went wrong, such as "Expected split chunks to be a list or tuple, but got {type(split_nodes_chunks)}".
| raise RuntimeError("Faulty split chunks") | |
| raise RuntimeError( | |
| f"Expected split chunks to be a list or tuple, but got " | |
| f"{type(split_nodes_chunks).__name__}: {split_nodes_chunks!r}" | |
| ) |
|
|
||
| # Check if split is even necessary - if not, remove it | ||
| if len(split_nodes_chunks) == 1: | ||
| getitem_node = list(split_node.users)[0] |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line assumes that the split node has exactly one user when there's only one chunk, but this assumption isn't validated. If the split node has zero or multiple users, accessing list(split_node.users)[0] could raise an IndexError or retrieve the wrong node. Consider adding validation or using a safer access pattern.
| getitem_node = list(split_node.users)[0] | |
| split_users = list(split_node.users) | |
| # Only apply this optimization when the split node has exactly one user. | |
| # If there are zero or multiple users, we cannot safely assume which one | |
| # should be rewired to the input, so we skip this transformation. | |
| if len(split_users) != 1: | |
| continue | |
| getitem_node = split_users[0] |
|
@novak-vaclav Please do not use |
b387e40 to
07d1da0
Compare
Summary
adds support for aten.split operator
Test plan
tests can be manually run using
pytest -c /dev/null backends/nxp/tests/cc @robert-kalmar @MartinPavella