Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pytensor/tensor/rewriting/linalg/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytensor.tensor.linalg.constructors import BlockDiagonal
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky
from pytensor.tensor.linalg.decomposition.lu import lu_factor
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.solvers.core import SolveBase
from pytensor.tensor.linalg.solvers.general import Solve, lu_solve
from pytensor.tensor.linalg.solvers.linear_control import (
Expand Down Expand Up @@ -163,6 +164,23 @@ def scalar_solve_to_division(fgraph, node):
return [new_out]


@register_stabilize
@node_rewriter([blockwise_of(SolveBase)])
def solve_of_inv_to_matmul(fgraph, node):
"""Replace solve(matrix_inverse(X), b) with X @ b.

If A = inv(X), then solve(A, b) finds x such that A @ x = b,
i.e., inv(X) @ x = b, so x = X @ b.
"""
A, b = node.inputs

match A.owner_op_and_inputs:
case (Blockwise(MatrixInverse()), X):
new_out = X @ b
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_canonicalize
@register_stabilize
@node_rewriter([blockwise_of(SolveBase)])
Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/rewriting/linalg/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import ancestors
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.scan.op import Scan
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.linalg.constructors import BlockDiagonal
Expand All @@ -28,6 +29,7 @@
scan_split_non_sequence_decomposition_and_solve,
)
from pytensor.tensor.type import matrix, tensor
from tests.unittest_tools import assert_equal_computations


def test_generic_solve_to_solve_triangular():
Expand Down Expand Up @@ -443,6 +445,22 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
np.testing.assert_allclose(resx0, resx1, rtol=rtol)


@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    """Check that ``solve(inv(X), b)`` is rewritten into the graph of ``X @ b``."""
    X = pt.dmatrix("X")
    if b_ndim == 1:
        b = pt.dvector("b")
    else:
        b = pt.dmatrix("b")

    solve_graph = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    # 'stabilize' is requested explicitly because that is where
    # solve_of_inv_to_matmul is registered.
    actual = rewrite_graph(solve_graph, include=["stabilize"])

    # The reference product goes through the same rewrite passes so that both
    # graphs end up in the same form before the structural comparison.
    reference = rewrite_graph(X @ b, include=["stabilize"])
    assert_equal_computations([actual], [reference])


@pytest.mark.parametrize(
"b_ndim, solve_fn, expected_op, batch",
[
Expand Down
Loading