Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,35 @@ def L_op(self, inputs, outputs, output_gradients):
if self.mode == "add":
return [cumsum(gi[reverse_slicing], self.axis)[reverse_slicing]]
elif self.mode == "mul":
fx = cumprod(x, axis=self.axis)
return [cumsum((fx * gi)[reverse_slicing], self.axis)[reverse_slicing] / x]
# The naive formula: cumsum_reverse(cumprod(x) * g) / x
# gives 0/0 = NaN when x contains zeros. We handle zeros
# by splitting into cases based on cumulative zero count.
axis = self.axis

is_zero = pt_eq(x, 0)
x_safe = switch(is_zero, ptb.ones_like(x), x)
fx_safe = cumprod(x_safe, axis=axis)

# Cumulative zero count along axis
is_zero_float = switch(is_zero, ptb.ones_like(x), ptb.zeros_like(x))
cum_zeros = cumsum(is_zero_float, axis=axis)

# True cumprod: zero wherever any zero has been seen
fx = fx_safe * pt_eq(cum_zeros, 0)

# Gradient for non-zero positions (0 at and after zeros)
naive_grad = (
cumsum((fx * gi)[reverse_slicing], axis)[reverse_slicing] / x_safe
)

# Gradient for first-zero positions: mask out contributions
# from positions after the second zero
h_masked = (fx_safe * gi) * lt(cum_zeros, 2)
zero_grad = cumsum(h_masked[reverse_slicing], axis)[reverse_slicing]

# Combine: use zero_grad at first-zero positions, naive elsewhere
first_zero = is_zero * pt_eq(cum_zeros, 1)
return [switch(first_zero, zero_grad, naive_grad)]
else:
raise NotImplementedError(
f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
Expand Down
38 changes: 38 additions & 0 deletions tests/tensor/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,44 @@ def test_grad(self):
utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
utt.verify_grad(self.op_class(axis=axis, mode="mul"), [a], eps=4e-4)

def test_grad_cumprod_with_zeros(self):
    """Regression test: the cumprod gradient must stay finite at zeros.

    The naive backward formula ``cumsum_reverse(cumprod(x) * g) / x``
    divides by the input and therefore produces 0/0 = NaN wherever
    ``x`` contains a zero; these cases pin the analytically correct
    values instead.
    """
    x = vector("x", dtype="float64")
    grad_fn = pytensor.function(
        [x], pytensor.grad(cumprod(x, axis=0).sum(), x)
    )

    # (input, analytically derived gradient of cumprod(input).sum())
    cases_1d = [
        # single zero in the middle
        ([3.0, 0.0, 5.0], [1.0, 18.0, 0.0]),
        # zero at the beginning
        ([0.0, 7.0, 3.0], [29.0, 0.0, 0.0]),
        # multiple zeros
        ([3.0, 0.0, 0.0, 5.0], [1.0, 3.0, 0.0, 0.0]),
        # all zeros
        ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]),
    ]
    for values, expected in cases_1d:
        np.testing.assert_allclose(
            grad_fn(np.array(values)), np.array(expected)
        )

    # Same check along axis=1 of a matrix: each row is an independent
    # cumprod, so zeros in one row must not affect the other.
    x2 = dmatrix("x2")
    grad_fn_2d = pytensor.function(
        [x2], pytensor.grad(cumprod(x2, axis=1).sum(), x2)
    )
    inp_2d = np.array([[3.0, 0.0, 5.0], [7.0, 4.0, 0.0]])
    want_2d = np.array([[1.0, 18.0, 0.0], [5.0, 7.0, 28.0]])
    np.testing.assert_allclose(grad_fn_2d(inp_2d), want_2d)


class TestBinCount(utt.InferShapeTester):
@pytest.mark.parametrize(
Expand Down
Loading