Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

C = Mode("c", "fast_run")
CVM = Mode("cvm", "fast_run")
VM = (Mode("vm", "fast_run"),)
VM = Mode("vm", "fast_run")

NUMBA = Mode(
NumbaLinker(),
Expand Down
61 changes: 35 additions & 26 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
copy_stack_trace,
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr
from pytensor.tensor.basic import (
AllocDiag,
Expand Down Expand Up @@ -132,40 +131,46 @@ def transinv_to_invtrans(fgraph, node):


@register_stabilize
@node_rewriter([Dot])
@node_rewriter([Dot, _matmul])
def inv_as_solve(fgraph, node):
"""
This utilizes a boolean `symmetric` tag on the matrices.
"""
if isinstance(node.op, Dot):
l, r = node.inputs
if (
l.owner
and isinstance(l.owner.op, Blockwise)
and isinstance(l.owner.op.core_op, MatrixInverse)
):
return [solve(l.owner.inputs[0], r)]
if (
r.owner
and isinstance(r.owner.op, Blockwise)
and isinstance(r.owner.op.core_op, MatrixInverse)
):
x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True:
return [solve(x, (l.mT)).mT]
else:
return [solve((x.mT), (l.mT)).mT]
l, r = node.inputs

# inv(A) @ B → solve(A, B)
if (
l.owner
and isinstance(l.owner.op, Blockwise)
and isinstance(l.owner.op.core_op, MatrixInverse)
):
A = l.owner.inputs[0]
B = r

return [solve(A, B)]

# A @ inv(B) → solve(B.T, A.T).T ← THIS is the critical fix
if (
r.owner
and isinstance(r.owner.op, Blockwise)
and isinstance(r.owner.op.core_op, MatrixInverse)
):
B = r.owner.inputs[0]
A = l

return [solve(B.T, A.T).T]

return None


@register_stabilize
@register_canonicalize
@node_rewriter([blockwise_of(OpPattern(Solve, assume_a="gen"))])
@node_rewriter([blockwise_of(Solve)])
def generic_solve_to_solve_triangular(fgraph, node):
"""
If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.

"""
if node.op.core_op.assume_a != "gen":
return None
A, b = node.inputs # result is the solution to Ax=b
if (
A.owner
Expand Down Expand Up @@ -194,12 +199,14 @@ def generic_solve_to_solve_triangular(fgraph, node):


@register_specialize
@node_rewriter([blockwise_of(OpPattern(SolveBase, b_ndim=1))])
@node_rewriter([blockwise_of(SolveBase)])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
"""
if node.op.core_op.b_ndim != 1:
return None
core_op = node.op.core_op
[a, b] = node.inputs

Expand Down Expand Up @@ -251,11 +258,13 @@ def no_transpose_symmetric(fgraph, node):


@register_stabilize
@node_rewriter([blockwise_of(OpPattern(Solve, b_ndim=2))])
@node_rewriter([blockwise_of(Solve)])
def psd_solve_with_chol(fgraph, node):
"""
This utilizes a boolean `psd` tag on matrices.
"""
if node.op.core_op.b_ndim != 2:
return None
A, b = node.inputs # result is the solution to Ax=b
if getattr(A.tag, "psd", None) is True:
L = cholesky(A)
Expand Down
40 changes: 40 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,3 +1128,43 @@ def solve_op_in_graph(graph):
np.testing.assert_allclose(
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)


def test_simplify_transpose_solve_transpose():
"""
Test that transpose(solve(transpose(A), transpose(B))) is simplified to solve(A, B).

This verifies:
1. The rewrite removes unnecessary transpose operations.
2. The optimized graph still produces correct numerical results.
"""

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise

A = pt.matrix("A")
B = pt.matrix("B")

expr = pt.matmul(A, pt.linalg.inv(B))

f = pytensor.function([A, B], expr, mode="FAST_RUN")

topo = f.maker.fgraph.toposort()

# Ensure MatrixInverse has been eliminated
assert not any(
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, MatrixInverse)
for node in topo
)

# Numeric correctness test
A_val = np.array([[1.0, 2.0], [3.0, 4.0]])
B_val = np.array([[5.0, 6.0], [7.0, 8.0]])

result = f(A_val, B_val)
expected = A_val @ np.linalg.inv(B_val)

np.testing.assert_allclose(result, expected)