Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/xarray_einstats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
eigvals,
eigvalsh,
inv,
matmul,
matrix_power,
matrix_rank,
matrix_transpose,
norm,
pinv,
qr,
slogdet,
solve,
Expand All @@ -31,14 +33,18 @@ class LinAlgAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def matrix_transpose(self, dims):
def matrix_transpose(self, dims=None):
"""Call :func:`xarray_einstats.linalg.matrix_transpose` on this DataArray."""
return matrix_transpose(self._obj, dims=dims)

def matrix_power(self, n, dims=None, **kwargs):
"""Call :func:`xarray_einstats.linalg.matrix_power` on this DataArray."""
return matrix_power(self._obj, n, dims=dims, **kwargs)

def matmul(self, other, dims=None, **kwargs):
"""Call :func:`xarray_einstats.linalg.matmul` with this DataArray as ``a/da``."""
return matmul(self._obj, other, dims=dims, **kwargs)

def cholesky(self, dims=None, **kwargs):
"""Call :func:`xarray_einstats.linalg.cholesky` on this DataArray."""
return cholesky(self._obj, dims=dims, **kwargs)
Expand Down Expand Up @@ -120,6 +126,10 @@ def inv(self, dims=None, **kwargs):
"""Call :func:`xarray_einstats.linalg.inv` on this DataArray."""
return inv(self._obj, dims=dims, **kwargs)

def pinv(self, dims=None, **kwargs):
"""Call :func:`xarray_einstats.linalg.pinv` on this DataArray."""
return pinv(self._obj, dims=dims, **kwargs)


@xr.register_dataarray_accessor("einops")
class EinopsAccessor:
Expand Down
2 changes: 2 additions & 0 deletions src/xarray_einstats/accessors.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from .linalg import (
eigvals,
eigvalsh,
inv,
matmul,
matrix_power,
matrix_rank,
matrix_transpose,
Expand All @@ -28,6 +29,7 @@ class LinAlgAccessor:
def __init__(self, xarray_obj: Incomplete) -> None: ...
def matrix_transpose(self, dims: Incomplete) -> None: ...
def matrix_power(self, n: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ...
def matmul(self, other: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ...
def cholesky(self, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ...
def qr(
self,
Expand Down
23 changes: 19 additions & 4 deletions src/xarray_einstats/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,17 @@ def matmul(da, db, dims=None, *, out_append="2", **kwargs):
return matmul_aux


def matrix_transpose(da, dims):
def matrix_transpose(da, dims=None):
"""Transpose the underlying matrix without modifying the dimensions.

This convenience function uses :meth:`~xarray.DataArray.swap_dims` followed
This convenience function uses :meth:`~xarray.DataArray.rename` followed
by :meth:`~xarray.DataArray.transpose` to get the equivalent of a matrix transposition.

Parameters
----------
da : DataArray
Input DataArray
dims : list of str
dims : list of str, optional
Matrix dimensions

Returns
Expand All @@ -455,7 +455,22 @@ def matrix_transpose(da, dims):
if dims is None:
dims = _attempt_default_dims("matrix_transpose", da.dims)
dim1, dim2 = dims
return da.swap_dims({dim1: dim2, dim2: dim1}).transpose(..., *dims)
rename_dict = {dim1: dim2, dim2: dim1}

if (
dim1 in da.indexes
and dim2 in da.indexes
and len(da.indexes[dim1].names) == len(da.indexes[dim2].names)
and len(da.indexes[dim1].names) > 1
):
for sub_dim1, sub_dim2 in zip(da.indexes[dim1].names, da.indexes[dim2].names):
rename_dict[sub_dim1] = sub_dim2
rename_dict[sub_dim2] = sub_dim1

da_transposed = da.rename(rename_dict).transpose(..., *dims)

# Purely cosmetic change to preserve order of coordinates in the output
return da_transposed.assign_coords({k: da_transposed.coords[k] for k in da.coords})


def matrix_power(da, n, dims=None, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/xarray_einstats/linalg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def matmul(
out_append: str = ...,
**kwargs: Incomplete,
) -> xarray.DataArray: ...
def matrix_transpose(da: xarray.DataArray, dims: list[str]) -> xarray.DataArray: ...
def matrix_transpose(da: xarray.DataArray, dims: list[str] | None = ...) -> xarray.DataArray: ...
def matrix_power(
da: xarray.DataArray, n: int, dims: Sequence[Hashable] | None = ..., **kwargs: Incomplete
) -> xarray.DataArray: ...
Expand Down
4 changes: 4 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def test_pinv_dataarray_tol(self, matrices, kind):
def test_transpose(self, hermitian):
assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2")))

def test_transpose_multiindex(self, matrices):
stacked = matrices.stack(batch_experiment=("batch", "experiment"), dim_dim2=("dim", "dim2"))
matrix_transpose(stacked, dims=("batch_experiment", "dim_dim2"))

def test_matrix_power(self, matrices):
out = matrix_power(matrices, 2, dims=("dim", "dim2"))
assert out.shape == matrices.shape
Expand Down