Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pytensor/tensor/linalg/products.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import scipy.linalg as scipy_linalg

from pytensor import tensor as pt
Expand Down Expand Up @@ -126,9 +128,15 @@ def matrix_dot(*args):
:math:`A_0 \cdot A_1 \cdot A_2 \cdot .. \cdot A_N`.

"""
warnings.warn(
"matrix_dot is deprecated and will be removed in future version.",
DeprecationWarning,
stacklevel=2,
)

rval = args[0]
for a in args[1:]:
rval = ptm.dot(rval, a)
rval = ptm.matmul(rval, a)
Comment thread
jessegrabowski marked this conversation as resolved.
return rval


Expand Down
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytensor.tensor.rewriting.linalg.decomposition
import pytensor.tensor.rewriting.linalg.inverse
import pytensor.tensor.rewriting.linalg.products
import pytensor.tensor.rewriting.linalg.reassociate_matmul
import pytensor.tensor.rewriting.linalg.solvers
import pytensor.tensor.rewriting.linalg.summary
import pytensor.tensor.rewriting.linalg.utils
5 changes: 1 addition & 4 deletions pytensor/tensor/rewriting/linalg/products.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import ExtractDiag, concatenate, diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg.constructors import BlockDiagonal
Expand Down
Loading
Loading