Add VJP for cumulative max and min scans #3599

Open
devin-lai wants to merge 2 commits into ml-explore:main from devin-lai:cummax-cummin-vjp

@devin-lai

Summary

  • Implement reverse-mode VJP for cummax and cummin instead of raising at gradient time.
  • Route cotangents to the latest running-extreme owner in scan order, including reverse and exclusive scans.
  • Add Python and C++ autograd tests covering max/min, forward/reverse scans, inclusive/exclusive modes, ties, and non-uniform cotangents.

Details

cummax and cummin are public forward ops, but mx.grad through them previously raised because cumulative min/max VJP was not implemented. This follows the existing Scan::vjp TODO: mark entries equal to the inclusive running extreme, scan over those indices to recover the owning input position, then accumulate cotangents with scatter_add_axis.
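The mark, scan, and scatter steps above can be sketched in plain Python for the simplest case, a forward inclusive `cummax` over a 1-D list. This is an illustration of the idea only, not the code in `mlx/primitives.cpp`; the function name `cummax_vjp_1d` is hypothetical, and the final loop stands in for what `scatter_add_axis` does on arrays.

```python
def cummax_vjp_1d(x, cotangent):
    """Illustrative VJP of a forward, inclusive cummax over a 1-D list.

    Hypothetical helper for explanation; not part of MLX.
    """
    n = len(x)

    # Step 1: inclusive running maximum.
    running = []
    for v in x:
        running.append(v if not running else max(running[-1], v))

    # Step 2: mark entries equal to the running extreme by keeping their
    # index; unmarked entries get -1 so they never win the next scan.
    marked = [i if x[i] == running[i] else -1 for i in range(n)]

    # Step 3: a running max over the marked indices recovers, for every
    # output position, the latest input position that owns the extreme.
    owner = []
    for m in marked:
        owner.append(m if not owner else max(owner[-1], m))

    # Step 4: accumulate each output cotangent into its owner
    # (the role scatter_add_axis plays in the array version).
    grad = [0.0] * n
    for i in range(n):
        grad[owner[i]] += cotangent[i]
    return grad


# x[1] owns outputs 1 and 2; the tie at x[3] takes over output 3.
print(cummax_vjp_1d([1.0, 3.0, 2.0, 3.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]))
# [1.0, 5.0, 0.0, 4.0, 5.0]
```

Index 0 always equals its own running maximum, so every output has a valid owner and the `-1` sentinel never reaches the scatter step.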

For ties, the VJP uses the latest occurrence in scan order. That convention matches the owner selected by PyTorch cummax/cummin backward in the cases cross-checked locally. Exclusive scans first shift cotangents by one step in the scan direction because the current element is excluded from its own output.
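To make the tie and exclusive conventions concrete, here is a brute-force reference written directly from the description above. It is a sketch under stated assumptions: the name `scan_extreme_vjp_reference` and its parameters are hypothetical, it is quadratic in the input length, and it reads the exclusive case as "output `i` is owned by whichever input owned the inclusive extreme one step earlier in scan order, and the first output in scan order has no owner".

```python
def scan_extreme_vjp_reference(x, cotangent, op=max, reverse=False, inclusive=True):
    """Brute-force reference for the cummax/cummin VJP on a 1-D list.

    Hypothetical helper for explanation; not part of MLX.
    """
    n = len(x)
    # Visit positions in scan order: left to right, or right to left
    # for reverse scans.
    order = list(range(n - 1, -1, -1)) if reverse else list(range(n))
    grad = [0.0] * n
    for step, out_pos in enumerate(order):
        # Inputs seen so far; an exclusive scan leaves out the current one.
        window = order[: step + 1] if inclusive else order[:step]
        if not window:
            # First output of an exclusive scan depends on no input.
            continue
        extreme = op(x[j] for j in window)
        # Ties: the latest occurrence in scan order owns the output.
        owner = [j for j in window if x[j] == extreme][-1]
        grad[owner] += cotangent[out_pos]
    return grad


x = [2.0, 2.0, 1.0, 2.0]
ct = [1.0, 10.0, 100.0, 1000.0]
print(scan_extreme_vjp_reference(x, ct))                   # [1.0, 110.0, 0.0, 1000.0]
print(scan_extreme_vjp_reference(x, ct, inclusive=False))  # [10.0, 1100.0, 0.0, 0.0]
print(scan_extreme_vjp_reference(x, ct, reverse=True))     # [1.0, 10.0, 0.0, 1100.0]
```

Distinct cotangent weights make it easy to see which output was routed to which input, which is why non-uniform cotangents are useful in the tests.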

The added work (a small number of scan ops plus one scatter_add_axis) runs only on the VJP path, which previously threw, so it does not regress existing differentiable workloads.

Tests

  • /Users/ldy/Library/Python/3.11/bin/pre-commit run --files mlx/primitives.cpp python/tests/test_autograd.py tests/autograd_tests.cpp
  • PYTHONPATH=python:python/tests python3.11 -m unittest python.tests.test_autograd
  • build_tests/tests/tests --test-case="test scan grads"
  • Local PyTorch cross-check: 24 VJP comparisons across cummax/cummin, axes, reverse, inclusive/exclusive, ties, and weighted cotangents.

`mx.grad` through `cummax` and `cummin` previously raised because reverse-mode differentiation was not implemented for cumulative min/max scans.

Route each output cotangent to the input element that owns the running extreme at that position, using the latest occurrence in scan order for ties. The owner index is reconstructed from the inclusive running extreme and then accumulated with `scatter_add_axis`; exclusive scans shift cotangents by one step in the scan direction.

Add Python and C++ tests for max/min scans across forward, reverse, inclusive, and exclusive modes, including tie cases and non-uniform cotangents. Forward-mode JVP remains unimplemented, matching the existing `cumprod` behavior.

@zcbenz (Collaborator) left a comment

> • Local PyTorch cross-check: 24 VJP comparisons across cummax/cummin, axes, reverse, inclusive/exclusive, ties, and weighted cotangents.

If the cross-check was done locally, I think it would be very useful to run it in the Python tests.

Move the local PyTorch comparison for cumulative max/min gradients into the Python autograd tests. The new test covers cummax and cummin over each axis, forward and reverse scans, inclusive and exclusive modes, tie cases, and weighted cotangents.

PyTorch does not expose MLX's reverse or exclusive scan options directly, so the test models reverse scans with a flip and exclusive scans by shifting the inclusive result by one step. The test is skipped when PyTorch is unavailable, matching the other PyTorch reference tests in the suite.
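The flip and shift modeling described above can be sketched without PyTorch. The example below uses `itertools.accumulate` as a stand-in for an inclusive forward scan; the function name `reference_scan` and the `pad` argument are hypothetical, and the value placed in the vacated slot of an exclusive scan is an assumption made for illustration (it does not depend on the input, so it receives no gradient).

```python
from itertools import accumulate


def reference_scan(x, op=max, reverse=False, inclusive=True, pad=None):
    """Model reverse/exclusive scans using only an inclusive forward scan.

    Hypothetical helper for explanation; not the test code in the PR.
    """
    # Reverse scan: flip the input, scan forward, flip the result back.
    seq = list(x)[::-1] if reverse else list(x)
    out = list(accumulate(seq, op))  # inclusive forward running extreme
    if not inclusive:
        # Exclusive scan: shift the inclusive result by one step so each
        # output sees only earlier inputs; the first slot takes `pad`.
        out = ([pad] + out)[: len(out)]
    return out[::-1] if reverse else out


x = [1, 3, 2, 5, 4]
neg_inf = float("-inf")  # assumed padding value for the max case
print(reference_scan(x))                                 # [1, 3, 3, 5, 5]
print(reference_scan(x, reverse=True))                   # [5, 5, 5, 5, 4]
print(reference_scan(x, inclusive=False, pad=neg_inf))   # [-inf, 1, 3, 3, 5]
```

In an autograd framework the same composition lets the reference gradients come entirely from the inclusive forward op plus differentiable flips and slices.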

@devin-lai (Author)

> • Local PyTorch cross-check: 24 VJP comparisons across cummax/cummin, axes, reverse, inclusive/exclusive, ties, and weighted cotangents.
>
> If the cross-check was done locally, I think it would be very useful to run it in the Python tests.

Thanks!!! I moved the local cross-check into the Python autograd tests. It now covers cummax/cummin over both axes, reverse/inclusive/exclusive combinations, tie cases, and weighted cotangents. The test is skipped if PyTorch is not available, consistent with the other PyTorch reference tests in the suite.
