Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/xarray_einstats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Stats, linear algebra and einops for xarray."""

from __future__ import annotations
from contextlib import contextmanager
from collections.abc import Iterable

import numpy as np
import xarray as xr
Expand All @@ -9,6 +11,7 @@
from .accessors import LinAlgAccessor, EinopsAccessor

__all__ = [
"default_linalg_dims",
"einsum",
"einsum_path",
"matmul",
Expand Down Expand Up @@ -188,3 +191,38 @@ def ones_ref(*args, dims, dtype=None):
empty_ref, zeros_ref
"""
return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype)


@contextmanager
def default_linalg_dims(func_or_dims):
"""Context manager to temporarily set the default dimensions for linalg functions.

Safer alternative to monkey patching the `get_default_dims` function in `linalg` module,
as it ensures that the original function is restored even if an error occurs within the context.

Parameters
----------
func_or_dims : callable or iterable
If a callable is provided, it should take the same arguments as `get_default_dims`
and return the default dimensions based on those arguments.
If an iterable is provided, it will be used as the default dimensions
regardless of the input arguments.

Yields
------
None
"""
from xarray_einstats import linalg

original_get_default_dims = linalg.get_default_dims

def func(*args):
if isinstance(func_or_dims, Iterable):
return func_or_dims
return func_or_dims(*args)

linalg.get_default_dims = func
try:
yield
finally:
linalg.get_default_dims = original_get_default_dims
7 changes: 7 additions & 0 deletions src/xarray_einstats/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Sequence
from contextlib import contextmanager
from typing import Any, Callable, Generator

import numpy as np
import xarray
Expand All @@ -13,6 +15,7 @@ from .accessors import EinopsAccessor, LinAlgAccessor
from .linalg import einsum, einsum_path, matmul

__all__ = [
"default_linalg_dims",
"einsum",
"einsum_path",
"matmul",
Expand Down Expand Up @@ -52,3 +55,7 @@ def ones_ref(
dims: Sequence[Hashable],
dtype: np.typing.DTypeLike | None = ...,
) -> xarray.DataArray: ...
@contextmanager
def default_linalg_dims(
func_or_dims: Callable | Iterable,
) -> Generator[None, Any, None]: ...