Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,6 @@ docs/examples_ipynb/
# Envs
.pixi/
.venv/
*.pem
*.db
array-api-tests/
7 changes: 4 additions & 3 deletions sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def _einsum_single(lhs, rhs, operand):

if lhs == rhs:
if not rhs:
# ensure scalar output
# full contraction — return 0-D array per the Array API standard
return operand.sum()
return operand

Expand Down Expand Up @@ -1390,8 +1390,9 @@ def _einsum_single(lhs, rhs, operand):
new_data = operand.data

if not rhs:
# scalar output - match numpy behaviour by not wrapping as array
return new_data.sum()
# full contraction — return 0-D COO array per the Array API standard
data = np.asarray(new_data.sum())
return COO.from_numpy(data)

return to_output_format(COO(new_coords, new_data, shape=new_shape, has_duplicates=True))

Expand Down
20 changes: 16 additions & 4 deletions sparse/numba_backend/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
check_consistent_fill_value,
check_zero_fill_value,
is_unsigned_dtype,
isscalar,
normalize_axis,
)

Expand Down Expand Up @@ -416,6 +415,19 @@ def nanmean(x, axis=None, keepdims=False, dtype=None, out=None):
return (num / den).astype(dtype if dtype is not None else x.dtype)


def _contains_nan(ar):
    """Check if a SparseArray or scalar contains any NaN values.

    The checks run from cheapest to most expensive: dtype first, then the
    fill value, and only then the stored data.

    Parameters
    ----------
    ar : SparseArray or scalar
        The reduction result to inspect.

    Returns
    -------
    bool or numpy.bool_
        Truthy when a NaN is present. For non-sparse input this is simply
        the result of ``np.isnan(ar)``.
    """
    if not isinstance(ar, SparseArray):
        return np.isnan(ar)

    # Only inexact dtypes (real floating *and* complex floating) can hold NaN.
    # Guarding on np.floating alone would silently skip complex arrays.
    if not np.issubdtype(ar.dtype, np.inexact):
        return False

    # The fill value only matters when at least one element is not stored
    # explicitly (this includes 0-D results with nnz == 0).
    if ar.nnz != ar.size and np.isnan(ar.fill_value):
        return True

    return np.isnan(ar.data).any()


def nanmax(x, axis=None, keepdims=False, dtype=None, out=None):
"""
Maximize along the given axes, skipping `NaN` values. Uses all axes by default.
Expand Down Expand Up @@ -446,7 +458,7 @@ def nanmax(x, axis=None, keepdims=False, dtype=None, out=None):

ar = x.reduce(np.fmax, axis=axis, keepdims=keepdims, dtype=dtype)

if (isscalar(ar) and np.isnan(ar)) or np.isnan(ar.data).any():
if _contains_nan(ar):
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=1)

return ar
Expand Down Expand Up @@ -482,7 +494,7 @@ def nanmin(x, axis=None, keepdims=False, dtype=None, out=None):

ar = x.reduce(np.fmin, axis=axis, keepdims=keepdims, dtype=dtype)

if (isscalar(ar) and np.isnan(ar)) or np.isnan(ar.data).any():
if _contains_nan(ar):
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=1)

return ar
Expand Down Expand Up @@ -901,7 +913,7 @@ def diagonalize(a, axis=0):
>>> a = sparse.random((3, 3, 3, 3, 3), density=0.3)
>>> a_diag = sparse.diagonalize(a, axis=2)
>>> (sparse.diagonal(a_diag, axis1=2, axis2=5) == a.transpose([0, 1, 3, 4, 2])).all()
np.True_
<COO: shape=(), dtype=bool, nnz=0, fill_value=True>

Returns
-------
Expand Down
18 changes: 9 additions & 9 deletions sparse/numba_backend/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def reduce(self, method, axis=(0,), keepdims=False, **kwargs):
axis = (axis,)
out = self._reduce_calc(method, axis, keepdims, **kwargs)
if len(out) == 1:
return out[0]
return out[0] if isinstance(out[0], SparseArray) else type(self).from_numpy(np.array(out[0]))
data, counts, axis, n_cols, arr_attrs = out
result_fill_value = self.fill_value
if reduce_super_ufunc is None:
Expand All @@ -422,7 +422,9 @@ def reduce(self, method, axis=(0,), keepdims=False, **kwargs):
out = out.reshape(shape)

if out.ndim == 0:
return out[()]
# Return a 0-D array per the Array API standard.
# The element value becomes the fill_value (nnz=0 is correct for 0-D).
return type(self).from_numpy(out.todense())

return out

Expand Down Expand Up @@ -689,7 +691,7 @@ def mean(self, axis=None, keepdims=False, dtype=None, out=None):
mean along all axes.

>>> s.mean()
np.float64(0.5)
Comment thread
Abineshabee marked this conversation as resolved.
<COO: shape=(), dtype=float64, nnz=0, fill_value=0.5>
"""

if axis is None:
Expand All @@ -709,10 +711,8 @@ def mean(self, axis=None, keepdims=False, dtype=None, out=None):

num = self.sum(axis=axis, keepdims=keepdims, dtype=inter_dtype)

if num.ndim:
out = np.true_divide(num, den, casting="unsafe")
return out.astype(dtype) if out.dtype != dtype else out
return np.divide(num, den, dtype=dtype, out=out)
out = np.true_divide(num, den, casting="unsafe")
return out.astype(dtype) if out.dtype != dtype else out

def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
Expand Down Expand Up @@ -769,7 +769,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
variance along all axes.

>>> s.var()
np.float64(0.5)
<COO: shape=(), dtype=float64, nnz=0, fill_value=0.5>
"""
axis = normalize_axis(axis, self.ndim)

Expand Down Expand Up @@ -803,7 +803,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):

ret = ret[...]
np.divide(ret, rcount, out=ret, casting="unsafe")
return ret[()]
return ret

def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
"""
Expand Down
66 changes: 66 additions & 0 deletions sparse/numba_backend/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,69 @@ def test_asarray(self, input, dtype, format):

if isinstance(input, SparseArray):
assert sparse.asarray(input).__class__ is input.__class__


class TestArrayAPIReductions:
    """
    Array API standard compliance checks for reductions: reducing over the
    whole array has to produce a zero-dimensional sparse array instead of a
    NumPy scalar.

    See: https://github.com/pydata/sparse/issues/921
    """

    @pytest.mark.parametrize("format", ["coo", "gcxs"])
    @pytest.mark.parametrize(
        "fn, expected",
        [
            (sparse.sum, 2.0),
            (sparse.max, 1.0),
            (sparse.min, 0.0),
            (sparse.prod, 0.0),
            (sparse.mean, 0.5),
        ],
    )
    def test_full_reduction_returns_0d_array(self, fn, expected, format):
        """Each numeric reduction of a 2x2 identity gives a 0-D sparse result."""
        reduced = fn(sparse.asarray(np.eye(2), format=format))

        assert reduced.ndim == 0, (
            f"{fn.__name__}() over entire array returned ndim={reduced.ndim}, expected 0-D array"
        )
        assert isinstance(reduced, SparseArray), (
            f"{fn.__name__}() returned {type(reduced).__name__}, expected a SparseArray"
        )

        # Convert once; the value is reused in the failure message.
        value = float(reduced)
        assert abs(value - expected) < 1e-9, f"{fn.__name__}() returned {value}, expected {expected}"

    @pytest.mark.parametrize("fn", [sparse.any, sparse.all])
    def test_boolean_reduction_returns_0d_array(self, fn):
        """Boolean reductions follow the same 0-D rule (COO input)."""
        reduced = fn(sparse.asarray(np.eye(2), format="coo"))

        assert reduced.ndim == 0, f"{fn.__name__}() returned ndim={reduced.ndim}, expected 0-D array"
        assert isinstance(reduced, SparseArray), (
            f"{fn.__name__}() returned {type(reduced).__name__}, expected a SparseArray"
        )

    def test_partial_reduction_still_returns_nd_array(self):
        """Axis-specific reductions must still return N-D sparse arrays."""
        identity = sparse.asarray(np.eye(2), format="coo")

        # Both axes of the 2x2 input collapse to a length-2 vector.
        for axis in (0, 1):
            partial = sparse.sum(identity, axis=axis)
            assert partial.shape == (2,), f"Expected shape (2,), got {partial.shape}"
            assert isinstance(partial, SparseArray)

    def test_keepdims_full_reduction(self):
        """keepdims=True must preserve all dimensions as size-1."""
        identity = sparse.asarray(np.eye(2), format="coo")
        kept = sparse.sum(identity, keepdims=True)

        assert kept.shape == (1, 1), f"Expected shape (1, 1), got {kept.shape}"
        assert isinstance(kept, SparseArray)

    @pytest.mark.parametrize("format", ["coo", "gcxs"])
    def test_1d_full_reduction_returns_0d_array(self, format):
        """1-D input fully reduced must also give a 0-D array."""
        vector = sparse.asarray(np.array([1.0, 2.0, 3.0]), format=format)
        total = sparse.sum(vector)

        assert total.ndim == 0, f"Expected 0-D array, got ndim={total.ndim}"
        assert isinstance(total, SparseArray)
        assert abs(float(total) - 6.0) < 1e-9
8 changes: 4 additions & 4 deletions sparse/numba_backend/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def test_einsum(subscripts, density):
numpy_out = np.einsum(subscripts, *(s.todense() for s in arrays))

if not numpy_out.shape:
# scalar output
assert np.allclose(numpy_out, sparse_out)
# scalar output — sparse_out is a 0-D COO per the Array API standard
assert np.allclose(numpy_out, sparse_out.todense())
else:
# array output
assert np.allclose(numpy_out, sparse_out.todense())
Expand All @@ -108,8 +108,8 @@ def test_einsum_nosubscript(input, density):
numpy_out = np.einsum(*(s.todense() for s in arrays), *input)

if not numpy_out.shape:
# scalar output
assert np.allclose(numpy_out, sparse_out)
# scalar output — sparse_out is a 0-D COO per the Array API standard
assert np.allclose(numpy_out, sparse_out.todense())
else:
# array output
assert np.allclose(numpy_out, sparse_out.todense())
Expand Down
Loading