Skip to content

Conversation

@justinchuby
Copy link
Collaborator

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby added the merge at lgtm Reviewers can merge when they approve label Jan 22, 2026
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements the prims::sum operation to fix issue #173074. The implementation adds support for tensor sum reduction with optional dimension specification and output dtype casting.

Changes:

  • Implements prims_sum function using ONNX ReduceSum operator with optional dtype casting
  • Adds end-to-end test for torch.std_mean which decomposes into prims.sum

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
onnxscript/function_libs/torch_lib/ops/prims.py Implements prims_sum function with ReduceSum and optional Cast operations
tests/function_libs/torch_lib/e2e_ops_tests.py Adds test case for torch.std_mean to validate the prims.sum implementation

inp: TensorType, dims: Optional[Sequence[int]], output_dtype: Optional[int] = None
) -> TensorType:
"""sum(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor"""

Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dims parameter should be normalized to None when it's an empty sequence, following the pattern used in prims_var (lines 839-841). ONNX's ReduceSum operator may not handle empty sequences correctly, and converting empty sequences to None ensures the operator reduces over all dimensions as expected.

Suggested change
if dims is not None and len(dims) == 0:
dims = None

Copilot uses AI. Check for mistakes.
@codecov
Copy link

codecov bot commented Jan 22, 2026

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
12883 1 12882 1191
View the top 1 failed test(s) by shortest run time
onnxscript.tools.memory_peak_test.TestMemoryPeak::test_spy
Stack Traces | 0.057s run time
worker 'gw1' crashed while running 'onnxscript/tools/memory_peak_test.py::TestMemoryPeak::test_spy'

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

@justinchuby justinchuby merged commit 7ec5d25 into main Jan 22, 2026
30 of 33 checks passed
@justinchuby justinchuby deleted the justinchu/prims-sum branch January 22, 2026 17:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge at lgtm Reviewers can merge when they approve module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

Can't convert std_mean to onnx (prims.sum.default)

3 participants