Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pytensor/tensor/rewriting/linalg/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor
from pytensor.tensor.linalg.decomposition.qr import QR
from pytensor.tensor.linalg.decomposition.svd import SVD
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.summary import SLogDet, det
from pytensor.tensor.math import Prod, log, prod
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -192,6 +193,16 @@ def det_of_triangular(fgraph, node):
return [det_val]


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def det_of_inv(fgraph, node):
"""Replace det(matrix_inverse(X)) with reciprocal(det(X))."""
match node.inputs[0].owner_op_and_inputs:
case (Blockwise(MatrixInverse()), X):
return [1 / det(X)]


@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
Expand Down
62 changes: 61 additions & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.type import (
complex_dtypes,
uint_dtypes,
Expand Down Expand Up @@ -689,6 +689,66 @@ def local_exp_log_nan_switch(fgraph, node):
return [new_out]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_div(fgraph, node):
"""Rewrite log(reciprocal(x)) -> -log(x) and log(a / b) -> log(a) - log(b).

A reciprocal is just ``1 / x``; log(a / b) only splits when a positive
constant operand is involved, so its log folds and the op count stays flat.
"""
(inp,) = node.inputs
if not (inp.owner and isinstance(inp.owner.op, Elemwise)):
return None
scalar_op = inp.owner.op.scalar_op

if isinstance(scalar_op, ps.Reciprocal):
return [neg(log(inp.owner.inputs[0]))]

if isinstance(scalar_op, ps.TrueDiv):
num, den = inp.owner.inputs
if (isinstance(num, Constant) and _is_provably_positive(num, strict=True)) or (
isinstance(den, Constant) and _is_provably_positive(den, strict=True)
):
return [log(num) - log(den)]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([sign])
def local_sign_div(fgraph, node):
"""Rewrite sign of a reciprocal or division from a known-sign operand.

sign(reciprocal(x)) -> sign(x). For sign(a / b): a provably positive side ->
``sign(other)``; a negative constant side -> ``-sign(other)``. Bails out
otherwise.
"""
(inp,) = node.inputs
if not (inp.owner and isinstance(inp.owner.op, Elemwise)):
return None
scalar_op = inp.owner.op.scalar_op

if isinstance(scalar_op, ps.Reciprocal):
return [sign(inp.owner.inputs[0])]

if not isinstance(scalar_op, ps.TrueDiv):
return None

num, den = inp.owner.inputs

if _is_provably_positive(num, strict=True):
return [sign(den)]
if _is_provably_positive(den, strict=True):
return [sign(num)]
Comment on lines +742 to +745
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this same name as what is imported? Github rendering as duplicated


for side, other in ((num, den), (den, num)):
if isinstance(side, Constant) and np.all(np.asarray(side.data) < 0):
return [neg(sign(other))]


@register_canonicalize
@register_specialize
@node_rewriter([Sum])
Expand Down
35 changes: 35 additions & 0 deletions tests/tensor/rewriting/linalg/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,38 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn):
expected = expected_fn(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


def test_det_of_inv():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this in test_summary?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because a determinant is a summary of a matrix

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah this should clearly be linalg/volume.py

x = pt.tensor("x", shape=(3, 3))
out = det(pt.linalg.inv(x))
expected = pt.as_tensor(1.0, dtype="float64") / det(x)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"])
assert_equal_computations([rewritten], [expected])


def test_slogdet_of_inv():
x = pt.dmatrix("x")
# slogdet(inv(x)) -> (sign, logabsdet)
sign_inv, logabsdet_inv = pt.linalg.slogdet(pt.linalg.inv(x))

# expected: (sign(det(x)), -logabsdet(det(x)))
# det(inv(x)) = 1/det(x), so sign is same.
# logabsdet(inv(x)) = log(abs(1/det(x))) = -log(abs(det(x)))
sign_x, logabsdet_x = pt.linalg.slogdet(x)
expected_sign = sign_x
expected_logabsdet = -logabsdet_x

# We need stabilize for det_of_inv and log_reciprocal
# and specialize for slogdet_specialization
rewritten_sign, rewritten_logabsdet = rewrite_graph(
[sign_inv, logabsdet_inv], include=["canonicalize", "stabilize", "specialize"]
)

expected_sign_opt, expected_logabsdet_opt = rewrite_graph(
[expected_sign, expected_logabsdet],
include=["canonicalize", "stabilize", "specialize"],
)

assert_equal_computations([rewritten_sign], [expected_sign_opt])
assert_equal_computations([rewritten_logabsdet], [expected_logabsdet_opt])
80 changes: 80 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5116,3 +5116,83 @@ def test_rewrite_does_not_apply(self):
original, include=("canonicalize", "stabilize", "specialize")
)
assert_equal_computations([rewritten], [original])


def test_log_reciprocal():
x = pt.dscalar("x")
out = pt.log(pt.reciprocal(x))
expected = -pt.log(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


def test_sign_reciprocal():
x = pt.dscalar("x")
out = pt.sign(pt.reciprocal(x))
expected = pt.sign(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


@pytest.mark.parametrize(
"build, expected_fn",
[
(lambda x: pt.log(3.0 / x), lambda x: pt.log(3.0) - pt.log(x)),
(lambda x: pt.log(x / 3.0), lambda x: pt.log(x) - pt.log(3.0)),
(lambda x: pt.log(1.0 / x), lambda x: -pt.log(x)),
],
ids=["pos_const_num", "pos_const_den", "one_over_x"],
)
def test_log_div_positive_constant(build, expected_fn):
x = pt.dscalar("x")
rewritten = rewrite_graph(
build(x), include=["canonicalize", "stabilize", "specialize"]
)
expected = rewrite_graph(
expected_fn(x), include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([rewritten], [expected])


def test_log_div_non_constant_not_rewritten():
x = pt.dscalar("x")
y = pt.dscalar("y")
out = pt.log(x / y)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"])
# No constant to peel off — graph should still contain a true_div.
nodes = [v.owner for v in ancestors([rewritten]) if v.owner]
assert any(
isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes
)


@pytest.mark.parametrize(
"build, expected_fn",
[
(lambda x: pt.sign(3.0 / x), lambda x: pt.sign(x)),
(lambda x: pt.sign(-3.0 / x), lambda x: -pt.sign(x)),
(lambda x: pt.sign(x / 3.0), lambda x: pt.sign(x)),
(lambda x: pt.sign(x / -3.0), lambda x: -pt.sign(x)),
],
ids=["pos_num", "neg_num", "pos_den", "neg_den"],
)
def test_sign_div_constant(build, expected_fn):
x = pt.dscalar("x")
rewritten = rewrite_graph(
build(x), include=["canonicalize", "stabilize", "specialize"]
)
expected = rewrite_graph(
expected_fn(x), include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([rewritten], [expected])


def test_sign_div_non_constant_not_rewritten():
x = pt.dscalar("x")
y = pt.dscalar("y")
out = pt.sign(x / y)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"])
nodes = [v.owner for v in ancestors([rewritten]) if v.owner]
assert any(
isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes
)
Loading