Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions xrspatial/tests/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,72 @@ def test_crop_nothing_to_crop():
assert result.shape == arr.shape
compare = arr == result.data
assert compare.all()


# ---------------------------------------------------------------------------
# Regression tests for #881: np.unique / np.isfinite must not materialise
# the full dask array.
# ---------------------------------------------------------------------------

@pytest.mark.skipif(not has_dask_array(), reason="dask.array not available")
def test_stats_does_not_materialise_dask_zones():
    """stats() with dask backend must never pass a dask array to np.unique."""
    from unittest import mock

    # Small zones/values rasters containing NaN and Inf so the
    # finite-filtering path inside zonal.stats is exercised.
    raw_zones = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, np.nan, 3, 3]])
    raw_values = np.array([[0, 0, 1, 1, 2, 2, 3, np.inf],
                           [0, 0, 1, 1, 2, np.nan, 3, 0],
                           [np.inf, 0, 1, 1, 2, 2, 3, 3]])

    zones = xr.DataArray(da.from_array(raw_zones, chunks=(3, 4)),
                         dims=['y', 'x'])
    values = xr.DataArray(da.from_array(raw_values, chunks=(3, 4)),
                          dims=['y', 'x'])

    # Capture the genuine np.unique before patching so the guard can
    # delegate to it for plain numpy inputs.
    real_unique = np.unique

    def fail_on_dask(arr, *args, **kwargs):
        # Any dask array reaching np.unique would be materialised in full.
        if isinstance(arr, da.Array):
            raise AssertionError(
                "np.unique called with a dask array — would materialise")
        return real_unique(arr, *args, **kwargs)

    patcher = mock.patch("xrspatial.zonal.np.unique", side_effect=fail_on_dask)
    with patcher:
        result = stats(zones, values)

    # The dask path yields a lazy dask DataFrame; compute to verify output.
    result = result.compute() if hasattr(result, 'compute') else result
    assert isinstance(result, pd.DataFrame)
    assert len(result) > 0


@pytest.mark.skipif(not has_dask_array(), reason="dask.array not available")
def test_crosstab_does_not_materialise_dask_zones():
    """crosstab() with dask backend must never pass a dask array to np.unique."""
    from unittest import mock

    # Zones with a NaN and values with a NaN to hit the finite-filter path.
    raw_zones = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, 2, 3, 3],
                          [0, 0, 1, 1, 2, np.nan, 3, 3]])
    raw_values = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, np.nan, 3, 0],
                           [0, 0, 1, 1, 2, 2, 3, 3]])

    zones = xr.DataArray(da.from_array(raw_zones, chunks=(3, 4)),
                         dims=['y', 'x'])
    values = xr.DataArray(da.from_array(raw_values, chunks=(3, 4)),
                          dims=['y', 'x'])

    # Keep a handle on the genuine np.unique; the guard delegates to it
    # for numpy inputs and rejects dask arrays outright.
    real_unique = np.unique

    def fail_on_dask(arr, *args, **kwargs):
        if isinstance(arr, da.Array):
            raise AssertionError(
                "np.unique called with a dask array — would materialise")
        return real_unique(arr, *args, **kwargs)

    patcher = mock.patch("xrspatial.zonal.np.unique", side_effect=fail_on_dask)
    with patcher:
        result = crosstab(zones, values)

    # The dask path yields a lazy dask DataFrame; compute to verify output.
    result = result.compute() if hasattr(result, 'compute') else result
    assert isinstance(result, pd.DataFrame)
    assert len(result) > 0
46 changes: 38 additions & 8 deletions xrspatial/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,35 @@ class cupy(object):
TOTAL_COUNT = '_total_count'


def _unique_finite_zones(arr):
    """Return the sorted unique finite values of *arr*.

    Dask-safe: when *arr* is a dask array the unique values are reduced
    chunk-by-chunk via ``da.unique``, so the full array is never pulled
    into memory at once; only the (small) set of uniques is computed.
    """
    is_dask = da is not None and isinstance(arr, da.Array)
    if not is_dask:
        # numpy path: filter non-finite entries first, then deduplicate.
        return np.unique(arr[np.isfinite(arr)])
    # dask path: deduplicate lazily per chunk, then filter the small
    # computed result with numpy.
    candidates = da.unique(arr).compute()
    return candidates[np.isfinite(candidates)]


def _unique_finite_cats(arr, nodata_values):
    """Return sorted unique values of *arr*, dropping NaN/Inf and *nodata_values*.

    Dask-safe: for dask arrays the uniques are reduced per chunk via
    ``da.unique`` so the full array is never materialised; the nodata and
    finiteness filters are then applied to the small computed result.
    """
    is_dask = da is not None and isinstance(arr, da.Array)
    if is_dask:
        candidates = da.unique(arr).compute()
        keep = np.isfinite(candidates)
        if nodata_values is not None:
            keep &= candidates != nodata_values
        return candidates[keep]
    # numpy path: mask out unwanted entries before deduplicating.
    keep = np.isfinite(arr)
    if nodata_values is not None:
        keep &= arr != nodata_values
    return np.unique(arr[keep])


def _stats_count(data):
if isinstance(data, np.ndarray):
# numpy case
Expand Down Expand Up @@ -187,7 +216,7 @@ def _stats_dask_numpy(
) -> pd.DataFrame:

# find ids for all zones
unique_zones = np.unique(zones[np.isfinite(zones)])
unique_zones = _unique_finite_zones(zones)

select_all_zones = False
# selecte zones to do analysis
Expand All @@ -199,7 +228,10 @@ def _stats_dask_numpy(
values_blocks = values.to_delayed().ravel()

stats_dict = {}
stats_dict["zone"] = unique_zones # zone column
stats_dict["zone"] = da.from_delayed( # zone column
delayed(lambda x: x)(unique_zones),
shape=(np.nan,), dtype=unique_zones.dtype,
)

compute_sum_squares = False
compute_sum = False
Expand Down Expand Up @@ -287,7 +319,7 @@ def _stats_numpy(
) -> Union[pd.DataFrame, np.ndarray]:

# find ids for all zones
unique_zones = np.unique(zones[np.isfinite(zones)])
unique_zones = _unique_finite_zones(zones)
# selected zones to do analysis
if zone_ids is None:
zone_ids = unique_zones
Expand Down Expand Up @@ -670,9 +702,7 @@ def stats(
def _find_cats(values, cat_ids, nodata_values):
if len(values.shape) == 2:
# 2D case
unique_cats = np.unique(values.data[
np.isfinite(values.data) & (values.data != nodata_values)
])
unique_cats = _unique_finite_cats(values.data, nodata_values)
else:
# 3D case
unique_cats = values[values.dims[0]].data
Expand Down Expand Up @@ -756,7 +786,7 @@ def _crosstab_numpy(
) -> pd.DataFrame:

# find ids for all zones
unique_zones = np.unique(zones[np.isfinite(zones)])
unique_zones = _unique_finite_zones(zones)
# selected zones to do analysis
if zone_ids is None:
zone_ids = unique_zones
Expand Down Expand Up @@ -894,7 +924,7 @@ def _crosstab_dask_numpy(
agg: str,
):
# find ids for all zones
unique_zones = np.unique(zones[np.isfinite(zones)])
unique_zones = _unique_finite_zones(zones)
if zone_ids is None:
zone_ids = unique_zones
else:
Expand Down