Use cumsum from flox #10987
base: main
@@ -13,6 +13,7 @@
 from packaging.version import Version

 from xarray.computation import ops
+from xarray.computation.apply_ufunc import apply_ufunc
 from xarray.computation.arithmetic import (
     DataArrayGroupbyArithmetic,
     DatasetGroupbyArithmetic,
@@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):
         return obj

+    def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
+        parsed_dim: tuple[Hashable, ...]
+        if isinstance(dim, str):
+            parsed_dim = (dim,)
+        elif dim is None:
+            parsed_dim_list = list()
+            # preserve order
+            for dim_ in itertools.chain(
+                *(grouper.codes.dims for grouper in self.groupers)
+            ):
+                if dim_ not in parsed_dim_list:
+                    parsed_dim_list.append(dim_)
+            parsed_dim = tuple(parsed_dim_list)
+        elif dim is ...:
+            parsed_dim = tuple(self._original_obj.dims)
+        else:
+            parsed_dim = tuple(dim)
+
+        return parsed_dim
+
     def _flox_reduce(
         self,
         dim: Dims,
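The new `_parse_dim` helper centralizes how the `dim` argument is normalized: a string becomes a one-element tuple, `None` expands to the grouped dimensions in first-seen order, `...` expands to every dimension of the original object, and any other iterable is passed through. Below is a minimal standalone sketch of that rule for illustration; the function name and the `grouper_dims`/`obj_dims` parameters are hypothetical stand-ins for `grouper.codes.dims` and `self._original_obj.dims`, not code from this PR.

```python
import itertools
from collections.abc import Hashable, Iterable
from types import EllipsisType


def parse_dim_sketch(
    dim: str | Iterable[Hashable] | EllipsisType | None,
    grouper_dims: Iterable[tuple[Hashable, ...]],
    obj_dims: tuple[Hashable, ...],
) -> tuple[Hashable, ...]:
    """Normalize ``dim`` the way the _parse_dim helper does."""
    if isinstance(dim, str):
        # a single dimension name becomes a one-element tuple
        return (dim,)
    if dim is None:
        # default: the dimensions of the group codes, deduplicated
        # while preserving their first-seen order
        seen: list[Hashable] = []
        for d in itertools.chain(*grouper_dims):
            if d not in seen:
                seen.append(d)
        return tuple(seen)
    if dim is ...:
        # Ellipsis means "every dimension of the original object"
        return obj_dims
    # any other iterable of names is passed through as a tuple
    return tuple(dim)


# one grouper whose codes live on "time"; object has dims ("time", "x")
assert parse_dim_sketch("time", [("time",)], ("time", "x")) == ("time",)
assert parse_dim_sketch(None, [("time",)], ("time", "x")) == ("time",)
assert parse_dim_sketch(..., [("time",)], ("time", "x")) == ("time", "x")
assert parse_dim_sketch(["x"], [("time",)], ("time", "x")) == ("x",)
```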
@@ -1088,22 +1109,7 @@ def _flox_reduce(
             # set explicitly to avoid unnecessarily accumulating count
             kwargs["min_count"] = 0

-        parsed_dim: tuple[Hashable, ...]
-        if isinstance(dim, str):
-            parsed_dim = (dim,)
-        elif dim is None:
-            parsed_dim_list = list()
-            # preserve order
-            for dim_ in itertools.chain(
-                *(grouper.codes.dims for grouper in self.groupers)
-            ):
-                if dim_ not in parsed_dim_list:
-                    parsed_dim_list.append(dim_)
-            parsed_dim = tuple(parsed_dim_list)
-        elif dim is ...:
-            parsed_dim = tuple(obj.dims)
-        else:
-            parsed_dim = tuple(dim)
+        parsed_dim = self._parse_dim(dim)

         # Do this so we raise the same error message whether flox is present or not.
         # Better to control it here than in flox.
@@ -1202,6 +1208,85 @@

         return result

+    def _flox_scan(
+        self,
+        dim: Dims,
+        *,
+        func: str,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+        **kwargs: Any,
+    ) -> DataArray:
+        from flox import groupby_scan
+
+        obj = self._original_obj
+
+        if skipna or (
+            skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO"
+        ):
+            if "nan" not in func and func not in ["all", "any", "count"]:
+                func = f"nan{func}"
+
+        # if keep_attrs is None:
+        #     keep_attrs = _get_keep_attrs(default=True)
+
+        parsed_dim = self._parse_dim(dim)
+
+        axis_ = obj.get_axis_num(parsed_dim)
+        axis = (axis_,) if isinstance(axis_, int) else axis_
+        codes = tuple(g.codes for g in self.groupers)
+        # g = groupby_scan(
+        #     obj.data,
+        #     *codes,
+        #     func=func,
+        #     expected_groups=None,
+        #     axis=axis,
+        #     dtype=None,
+        #     method=None,
+        #     engine=None,
+        # )
+        # result = obj.copy(data=g)
+
+        # return result
+
+        actual = apply_ufunc(
+            groupby_scan,
+            obj,
+            *codes,
+            # input_core_dims=input_core_dims,
+            # for xarray's test_groupby_duplicate_coordinate_labels
+            # exclude_dims=set(dim_tuple),
+            # output_core_dims=[output_core_dims],
+            dask="allowed",
+            # dask_gufunc_kwargs=dict(
+            #     output_sizes=output_sizes,
+            #     output_dtypes=[dtype] if dtype is not None else None,
+            # ),
+            keep_attrs=(
+                _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
+            ),
+            kwargs=dict(
+                func=func,
+                expected_groups=None,
+                axis=axis,
+                dtype=None,
+                method=None,
+                engine=None,
+            ),
+        )
+
+        return actual
+
+        # xarray_reduce(
+        #     obj.drop_vars(non_numeric.keys()),
+        #     *codes,
+        #     dim=parsed_dim,
+        #     expected_groups=expected_groups,
+        #     isbin=False,
+        #     keep_attrs=keep_attrs,
+        #     **kwargs,
+        # )
+
     def fillna(self, value: Any) -> T_Xarray:
         """Fill missing values in this object by group.
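For a rough picture of the behaviour this PR targets: with flox installed, a grouped cumulative sum can be computed in a single `flox.groupby_scan` call over the raw array, which is essentially what `_flox_scan` wires up through `apply_ufunc`. The sketch below compares that against xarray's existing `groupby(...).cumsum()`; the `nancumsum` name mirrors the nan-prefixing logic in `_flox_scan`, while the toy data and the exact `groupby_scan` keywords (`func`, `axis`) are assumptions based on flox's public API rather than taken from this diff.

```python
import numpy as np
import xarray as xr
import flox  # _flox_scan imports groupby_scan from here

da = xr.DataArray(
    [1.0, 2.0, 3.0, 4.0, 5.0],
    dims="x",
    coords={"labels": ("x", ["a", "a", "b", "a", "b"])},
)

# Existing behaviour: cumulative sum within each label group along "x".
expected = da.groupby("labels").cumsum()  # [1, 3, 3, 7, 8]

# Roughly the flox path: operate on the raw array and the label array
# directly, along the grouped axis, skipping NaNs (hence "nancumsum").
result = flox.groupby_scan(
    da.data,
    da["labels"].data,
    func="nancumsum",
    axis=-1,
)

np.testing.assert_allclose(result, [1.0, 3.0, 3.0, 7.0, 8.0])
np.testing.assert_allclose(expected.data, [1.0, 3.0, 3.0, 7.0, 8.0])
```

The apparent motivation is that the scan runs once over the whole array (and stays lazy with dask via `dask="allowed"`), rather than xarray mapping `cumsum` over each group and re-concatenating, though the cumsum-to-`_flox_scan` dispatch itself is not part of the hunks shown here.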