Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4262,8 +4262,57 @@ std::vector<array> Scan::vjp(
where(eq_zero, z, grad, stream()),
stream())};
} else {
// Can probably be implemented by equals and then cummax to make the mask
throw std::runtime_error("VJP is not implemented for cumulative min/max");
// Cumulative max/min: route each position's cotangent to the input element
// that holds the running extreme there, breaking ties toward the latest
// element in scan order. The owning index is reconstructed by marking
// where the input equals the inclusive running extreme and taking a running
// extreme over those indices (cummax picks the largest tied index for a
// forward scan and cummin the smallest for a reverse scan). The cotangents
// are then scatter-added. Exclusive scans skip the current element, so the
// cotangents are first shifted by one in the scan direction.
auto in = primals[0];
auto cotan = cotangents[0];
int n = in.shape(axis_);
auto s = stream();

auto running_extreme = outputs[0];
if (!inclusive_) {
running_extreme = (reduce_type_ == Scan::Max)
? cummax(in, axis_, reverse_, /* inclusive = */ true, s)
: cummin(in, axis_, reverse_, /* inclusive = */ true, s);
}

Shape index_shape(in.ndim(), 1);
index_shape[axis_] = n;
auto iota =
reshape(arange(static_cast<double>(n), int32, s), index_shape, s);
auto masked = where(
equal(in, running_extreme, s),
iota,
array(reverse_ ? n : -1, int32),
s);
auto owner = astype(
reverse_ ? cummin(masked, axis_, /* reverse = */ true, true, s)
: cummax(masked, axis_, /* reverse = */ false, true, s),
uint32,
s);

if (!inclusive_) {
Shape pad_shape = in.shape();
pad_shape[axis_] = 1;
auto zero = zeros(pad_shape, cotan.dtype(), s);
Shape start(in.ndim(), 0);
Shape stop = in.shape();
if (reverse_) {
stop[axis_] = n - 1;
cotan = concatenate({zero, slice(cotan, start, stop, s)}, axis_, s);
} else {
start[axis_] = 1;
cotan = concatenate({slice(cotan, start, stop, s), zero}, axis_, s);
}
}

return {scatter_add_axis(zeros_like(in, s), owner, cotan, axis_, s)};
}
}

Expand Down
171 changes: 171 additions & 0 deletions python/tests/test_autograd.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
# Copyright © 2023 Apple Inc.

import gc
import itertools
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np

try:
import torch

has_torch = True
except ImportError:
has_torch = False


class TestAutograd(mlx_tests.MLXTestCase):
Expand Down Expand Up @@ -590,6 +599,168 @@ def fun(y):
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))

def test_cummax_grad(self):
# Ties route to the latest occurrence, matching the cummax indices.
a = mx.array([3.0, 3.0, 1.0, 5.0, 5.0])

def fun(y):
return mx.cummax(y).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([1.0, 2.0, 0.0, 1.0, 1.0]))
)

def fun(y):
return mx.cummax(y, inclusive=False).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([1.0, 2.0, 0.0, 1.0, 0.0]))
)

def fun(y):
return mx.cummax(y, reverse=True).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([0.0, 0.0, 0.0, 4.0, 1.0]))
)

def fun(y):
return mx.cummax(y, reverse=True, inclusive=False).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([0.0, 0.0, 0.0, 3.0, 1.0]))
)

# Non-uniform cotangents are routed to the owning index.
cot = mx.array([10.0, 1.0, 1.0, 100.0, 1000.0])
_, vjps = mx.vjp(lambda y: mx.cummax(y), (a,), (cot,))
self.assertTrue(mx.allclose(vjps[0], mx.array([10.0, 2.0, 0.0, 100.0, 1000.0])))

# 2D along an inner axis.
m = mx.array([[1.0, 3.0, 3.0], [4.0, 2.0, 4.0]])

def fun(y):
return mx.cummax(y, axis=1).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(m), mx.array([[1.0, 1.0, 1.0], [2.0, 0.0, 1.0]]))
)

def fun(y):
return mx.cummax(y, axis=1, inclusive=False).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(m), mx.array([[1.0, 1.0, 0.0], [2.0, 0.0, 0.0]]))
)

def test_cummin_grad(self):
a = mx.array([3.0, 3.0, 1.0, 5.0, 5.0])

def fun(y):
return mx.cummin(y).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([1.0, 1.0, 3.0, 0.0, 0.0]))
)

def fun(y):
return mx.cummin(y, inclusive=False).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([1.0, 1.0, 2.0, 0.0, 0.0]))
)

def fun(y):
return mx.cummin(y, reverse=True).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([0.0, 0.0, 3.0, 1.0, 1.0]))
)

def fun(y):
return mx.cummin(y, reverse=True, inclusive=False).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(a), mx.array([0.0, 0.0, 2.0, 1.0, 1.0]))
)

cot = mx.array([10.0, 1.0, 1.0, 100.0, 1000.0])
_, vjps = mx.vjp(lambda y: mx.cummin(y), (a,), (cot,))
self.assertTrue(mx.allclose(vjps[0], mx.array([10.0, 1.0, 1101.0, 0.0, 0.0])))

# 2D along the outer axis.
m = mx.array([[1.0, 3.0, 3.0], [4.0, 2.0, 4.0]])

def fun(y):
return mx.cummin(y, axis=0).sum()

self.assertTrue(
mx.allclose(mx.grad(fun)(m), mx.array([[2.0, 1.0, 2.0], [0.0, 1.0, 0.0]]))
)

@unittest.skipIf(not has_torch, "requires Torch")
def test_cummax_cummin_grad_vs_torch(self):
# Cross-check the cumulative max/min VJP against PyTorch autograd over
# axes, scan direction, inclusive/exclusive modes, ties, and weighted
# cotangents. Torch has no reverse or exclusive scan, so reverse is
# emulated by flipping along the axis and exclusive by shifting the
# inclusive scan one step (its leading element carries no gradient).
def torch_scan(x, axis, reverse, inclusive, op):
xf = torch.flip(x, [axis]) if reverse else x
scan = torch.cummax if op == "max" else torch.cummin
c = scan(xf, axis).values
if not inclusive:
n = c.size(axis)
head = torch.zeros_like(c.narrow(axis, 0, 1)) # constant, no grad
c = torch.cat([head, c.narrow(axis, 0, n - 1)], dim=axis)
return torch.flip(c, [axis]) if reverse else c

def mx_scan(z, axis, reverse, inclusive, op):
scan = mx.cummax if op == "max" else mx.cummin
return scan(z, axis=axis, reverse=reverse, inclusive=inclusive)

inputs = [
np.array([3.0, 3.0, 1.0, 5.0, 5.0, 1.0, 5.0, 2.0], dtype=np.float32),
np.array(
[[1.0, 3.0, 3.0, 2.0], [4.0, 2.0, 4.0, 4.0], [4.0, 3.0, 1.0, 4.0]],
dtype=np.float32,
),
]
rng = np.random.default_rng(0)
for x_np in inputs:
cotangents = {
"ones": np.ones_like(x_np),
"weighted": rng.uniform(1.0, 9.0, x_np.shape).astype(np.float32),
}
for axis in range(x_np.ndim):
for op, reverse, inclusive in itertools.product(
("max", "min"), (False, True), (True, False)
):
for cot_name, cot_np in cotangents.items():
with self.subTest(
shape=x_np.shape,
axis=axis,
op=op,
reverse=reverse,
inclusive=inclusive,
cotangent=cot_name,
):
_, (mx_grad,) = mx.vjp(
lambda z: mx_scan(z, axis, reverse, inclusive, op),
(mx.array(x_np),),
(mx.array(cot_np),),
)

xt = torch.tensor(x_np, requires_grad=True)
out = torch_scan(xt, axis, reverse, inclusive, op)
(out * torch.tensor(cot_np)).sum().backward()

self.assertTrue(
np.allclose(
np.array(mx_grad), xt.grad.numpy(), atol=1e-5
)
)

def test_topk_grad(self):
a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)

Expand Down
62 changes: 62 additions & 0 deletions tests/autograd_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,68 @@ TEST_CASE("test scan grads") {
CHECK(array_equal(out, expected).item<bool>());
}

// Test cummax
{
int axis = 0;
int reverse = false;
int inclusive = true;
auto fun = [&axis, &reverse, &inclusive](array x) {
return cummax(x, axis, reverse, inclusive);
};

auto x = array({3.0f, 3.0f, 1.0f, 5.0f, 5.0f}, {5});
auto g = ones({5});
auto out = vjp(fun, x, g).second;
auto expected = array({1.0f, 2.0f, 0.0f, 1.0f, 1.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

reverse = true;
out = vjp(fun, x, g).second;
expected = array({0.0f, 0.0f, 0.0f, 4.0f, 1.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

inclusive = false;
out = vjp(fun, x, g).second;
expected = array({0.0f, 0.0f, 0.0f, 3.0f, 1.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

reverse = false;
out = vjp(fun, x, g).second;
expected = array({1.0f, 2.0f, 0.0f, 1.0f, 0.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());
}

// Test cummin
{
int axis = 0;
int reverse = false;
int inclusive = true;
auto fun = [&axis, &reverse, &inclusive](array x) {
return cummin(x, axis, reverse, inclusive);
};

auto x = array({3.0f, 3.0f, 1.0f, 5.0f, 5.0f}, {5});
auto g = ones({5});
auto out = vjp(fun, x, g).second;
auto expected = array({1.0f, 1.0f, 3.0f, 0.0f, 0.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

reverse = true;
out = vjp(fun, x, g).second;
expected = array({0.0f, 0.0f, 3.0f, 1.0f, 1.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

inclusive = false;
out = vjp(fun, x, g).second;
expected = array({0.0f, 0.0f, 2.0f, 1.0f, 1.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());

reverse = false;
out = vjp(fun, x, g).second;
expected = array({1.0f, 1.0f, 2.0f, 0.0f, 0.0f}, {5});
CHECK(array_equal(out, expected).item<bool>());
}

// Test cumsum jvp
{
int axis = 0;
Expand Down
Loading