-
Notifications
You must be signed in to change notification settings - Fork 98
[torchlib] prims::sum #2778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torchlib] prims::sum #2778
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements the prims::sum operation to fix issue #173074. The implementation adds support for tensor sum reduction with optional dimension specification and output dtype casting.
Changes:
- Implements
prims_sumfunction using ONNX ReduceSum operator with optional dtype casting - Adds end-to-end test for
torch.std_meanwhich decomposes intoprims.sum
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/prims.py | Implements prims_sum function with ReduceSum and optional Cast operations |
| tests/function_libs/torch_lib/e2e_ops_tests.py | Adds test case for torch.std_mean to validate the prims.sum implementation |
| inp: TensorType, dims: Optional[Sequence[int]], output_dtype: Optional[int] = None | ||
| ) -> TensorType: | ||
| """sum(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor""" | ||
|
|
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dims parameter should be normalized to None when it's an empty sequence, following the pattern used in prims_var (lines 839-841). ONNX's ReduceSum operator may not handle empty sequences correctly, and converting empty sequences to None ensures the operator reduces over all dimensions as expected.
| if dims is not None and len(dims) == 0: | |
| dims = None |
❌ 1 Tests Failed:
View the top 1 failed test(s) by shortest run time
To view more test analytics, go to the Test Analytics Dashboard |
Fix pytorch/pytorch#173074