Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
776bc5a
use cumsum from flox
Illviljan Dec 6, 2025
ae27632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a5f9326
Update groupby.py
Illviljan Dec 6, 2025
50ccca4
Update groupby.py
Illviljan Dec 6, 2025
f55531e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
06ac372
Update groupby.py
Illviljan Dec 6, 2025
31244e6
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
dd47536
Update groupby.py
Illviljan Dec 6, 2025
e867f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
88e0ebc
Update groupby.py
Illviljan Dec 6, 2025
181d4a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a82ec39
use apply_ufunc for dataset and dataarray handling
Illviljan Dec 6, 2025
6c6abed
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
24c3f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d8d0eaa
Update groupby.py
Illviljan Dec 6, 2025
55ff46a
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
33d1360
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
c97ae98
sync protocols with each other
Illviljan Dec 6, 2025
06b52ae
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
84f9b44
typing
Illviljan Dec 6, 2025
2978877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
0a9adee
add dataset and version requirement
Illviljan Dec 6, 2025
ae9a3d8
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
c056d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d4873b9
Update _aggregations.py
Illviljan Dec 6, 2025
21cbde2
Update xarray/core/groupby.py
Illviljan Dec 6, 2025
4aebc47
Update groupby.py
Illviljan Dec 6, 2025
f4cab24
Update groupby.py
Illviljan Dec 6, 2025
23d9d50
Update groupby.py
Illviljan Dec 6, 2025
9b64db2
Update generate_aggregations.py
Illviljan Dec 6, 2025
928b158
Renove workaround in test
Illviljan Dec 7, 2025
130f98e
Update _aggregations.py
Illviljan Dec 7, 2025
5a3e754
Update _aggregations.py
Illviljan Dec 7, 2025
d912cda
Update test_groupby.py
Illviljan Dec 7, 2025
3bc8dc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2025
ec8ffd6
clean ups
Illviljan Dec 7, 2025
b0cf8c4
Merge branch 'main' into cumsum_flox
Illviljan Dec 7, 2025
07a4d35
Add expected groups, add options
Illviljan Dec 8, 2025
d0f7ed2
Update groupby.py
Illviljan Dec 8, 2025
098be30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2025
87d5f77
expeced_groups not supported in groupby_scan
Illviljan Dec 8, 2025
16c93ea
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 8, 2025
dfe269a
Update _aggregations.py
Illviljan Dec 9, 2025
b2c3d51
Update _aggregations.py
Illviljan Dec 9, 2025
e28f458
Update generate_aggregations.py
Illviljan Dec 9, 2025
55a36ab
Update _aggregations.py
Illviljan Dec 9, 2025
ff531e1
Update _aggregations.py
Illviljan Dec 9, 2025
43aad2e
Update _aggregations.py
Illviljan Dec 9, 2025
8dfcc56
Update _aggregations.py
Illviljan Dec 9, 2025
9dac0a4
Update generate_aggregations.py
Illviljan Dec 9, 2025
0ba3504
Update _aggregations.py
Illviljan Dec 9, 2025
da2a3e3
Update _aggregations.py
Illviljan Dec 9, 2025
95e6fd3
Update _aggregations.py
Illviljan Dec 9, 2025
7d358b0
Update _aggregations.py
Illviljan Dec 9, 2025
f4fe7a0
Update _aggregations.py
Illviljan Dec 9, 2025
74f1073
Update _aggregations.py
Illviljan Dec 9, 2025
50f6209
Update _aggregations.py
Illviljan Dec 9, 2025
87675b2
Update _aggregations.py
Illviljan Dec 9, 2025
9aee62e
Update _aggregations.py
Illviljan Dec 9, 2025
02ee023
Update _aggregations.py
Illviljan Dec 9, 2025
82557c4
Update _aggregations.py
Illviljan Dec 9, 2025
9721574
Update _aggregations.py
Illviljan Dec 9, 2025
e1fba81
Update _aggregations.py
Illviljan Dec 9, 2025
5137fd8
Update _aggregations.py
Illviljan Dec 9, 2025
59a7f38
Update _aggregations.py
Illviljan Dec 9, 2025
7f519f0
Update _aggregations.py
Illviljan Dec 9, 2025
c4f5f83
Update _aggregations.py
Illviljan Dec 9, 2025
bf5197d
Update _aggregations.py
Illviljan Dec 9, 2025
5563600
Update _aggregations.py
Illviljan Dec 9, 2025
510300d
Update _aggregations.py
Illviljan Dec 9, 2025
5fe07df
Update _aggregations.py
Illviljan Dec 9, 2025
293cc1f
Update _aggregations.py
Illviljan Dec 9, 2025
d9f694c
Update _aggregations.py
Illviljan Dec 9, 2025
c9814db
Update _aggregations.py
Illviljan Dec 10, 2025
6ed0f99
Update _aggregations.py
Illviljan Dec 10, 2025
43a827d
Update test_groupby.py
Illviljan Dec 10, 2025
8d65562
Update test_groupby.py
Illviljan Dec 10, 2025
d19bbca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
acf4022
Update test_groupby.py
Illviljan Dec 10, 2025
f263da6
Update generate_aggregations.py
Illviljan Dec 10, 2025
e56d0b8
Update test_groupby.py
Illviljan Dec 10, 2025
8cbfd9d
Merge branch 'main' into cumsum_flox
Illviljan Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions xarray/core/_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6647,6 +6647,13 @@ def _flox_reduce(
) -> DataArray:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -7904,13 +7911,27 @@ def cumsum(
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down
117 changes: 101 additions & 16 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version

from xarray.computation import ops
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.computation.arithmetic import (
DataArrayGroupbyArithmetic,
DatasetGroupbyArithmetic,
Expand Down Expand Up @@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):

return obj

def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(self._original_obj.dims)
else:
parsed_dim = tuple(dim)

return parsed_dim

def _flox_reduce(
self,
dim: Dims,
Expand Down Expand Up @@ -1088,22 +1109,7 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(obj.dims)
else:
parsed_dim = tuple(dim)
parsed_dim = self._parse_dim(dim)

# Do this so we raise the same error message whether flox is present or not.
# Better to control it here than in flox.
Expand Down Expand Up @@ -1202,6 +1208,85 @@ def _flox_reduce(

return result

def _flox_scan(
self,
dim: Dims,
*,
func: str,
keep_attrs: bool | None = None,
skipna: bool | None = None,
**kwargs: Any,
) -> DataArray:
from flox import groupby_scan

obj = self._original_obj

if skipna or (
skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO"
):
if "nan" not in func and func not in ["all", "any", "count"]:
func = f"nan{func}"

# if keep_attrs is None:
# keep_attrs = _get_keep_attrs(default=True)

parsed_dim = self._parse_dim(dim)

axis_ = obj.get_axis_num(parsed_dim)
axis = (axis_,) if isinstance(axis_, int) else axis_
codes = tuple(g.codes for g in self.groupers)
# g = groupby_scan(
# obj.data,
# *codes,
# func=func,
# expected_groups=None,
# axis=axis,
# dtype=None,
# method=None,
# engine=None,
# )
# result = obj.copy(data=g)

# return result

actual = apply_ufunc(
groupby_scan,
obj,
*codes,
# input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
# exclude_dims=set(dim_tuple),
# output_core_dims=[output_core_dims],
dask="allowed",
# dask_gufunc_kwargs=dict(
# output_sizes=output_sizes,
# output_dtypes=[dtype] if dtype is not None else None,
# ),
keep_attrs=(
_get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
),
kwargs=dict(
func=func,
expected_groups=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be the same as _flox_reduce. This is an important optimization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expected_groups is not supported in groupby_scan. For a future PR I think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh dang

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look and remember why this. What do we do for group 0 when a user says grouped_scan(np.array([1, 2, 3], by=[0, 1, 2], expected_groups=[1, 2])?

Copy link
Contributor Author

@Illviljan Illviljan Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking it should be np.nan (or fill_value) for groups missing in expected groups.

An analog could be

  • groupby_reduce "uses" __getitem__ to mask missing groups.
  • groupby_scan will have to "use" np.where(mask, np.nan) to continue masking but with the same shape.

My expected result:

import flox
import numpy as np


# groupby_reduce omits 0:
flox.groupby_reduce(
    np.array([1, 2, 3]), [0, 1, 2], func="sum", expected_groups=[1, 2]
)
# (array([2, 3]), array([1, 2]))

flox.groupby_scan(
    np.array([1, 2, 3]), [0, 1, 2], func="cumsum", expected_groups=[1, 2]
)
# array([np.nan, 2, 3])

axis=axis,
dtype=None,
method=None,
engine=None,
),
)

return actual

# xarray_reduce(
# obj.drop_vars(non_numeric.keys()),
# *codes,
# dim=parsed_dim,
# expected_groups=expected_groups,
# isbin=False,
# keep_attrs=keep_attrs,
# **kwargs,
# )

def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.

Expand Down
Loading