Skip to content

Commit 7428689

Browse files
committed
Support skipna in core.groupby_reduce
This is the right place to do it since skipna can be optionally True for appropriate dtypes. It's hard to choose the right one when passing Dataset to apply_ufunc. Instead we do it here where we are always dealing with a single array and can choose based on dtype.
1 parent 51297ed commit 7428689

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

dask_groupby/core.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ def chunk_argreduce(
389389
if not np.isnan(results["groups"]).all():
390390
# will not work for empty groups...
391391
# glorious
392+
# TODO: npg bug
393+
results["intermediates"][1] = results["intermediates"][1].astype(int)
392394
newidx = np.broadcast_to(idx, array.shape)[
393395
np.unravel_index(results["intermediates"][1], array.shape)
394396
]
@@ -992,6 +994,7 @@ def groupby_reduce(
992994
isbin: bool = False,
993995
axis=None,
994996
fill_value=None,
997+
skipna: Optional[bool] = None,
995998
min_count: Optional[int] = None,
996999
split_out: int = 1,
9971000
method: str = "mapreduce",
@@ -1020,6 +1023,16 @@ def groupby_reduce(
10201023
Negative integers are normalized using array.ndim
10211024
fill_value: Any
10221025
Value when a label in `expected_groups` is not present
1026+
skipna : bool, default: None
1027+
If True, skip missing values (as marked by NaN). By default, only
1028+
skips missing values for float dtypes; other dtypes either do not
1029+
have a sentinel missing value (int) or ``skipna=True`` has not been
1030+
implemented (object, datetime64 or timedelta64).
1031+
min_count : int, default: None
1032+
The required number of valid values to perform the operation. If
1033+
fewer than min_count non-NA values are present the result will be
1034+
NA. Only used if skipna is set to True or defaults to True for the
1035+
array's dtype.
10231036
split_out: int, optional
10241037
Number of chunks along group axis in output (last axis)
10251038
method: {"mapreduce", "blockwise", "cohorts"}, optional
@@ -1062,10 +1075,24 @@ def groupby_reduce(
10621075
f"Received array of shape {array.shape} and by of shape {by.shape}"
10631076
)
10641077

1065-
if min_count is not None and min_count > 1 and func not in ["nansum", "nanprod"]:
1066-
raise ValueError(
1067-
"min_count can be > 1 only for nansum, nanprod. This is an Xarray limitation."
1068-
)
1078+
# Handle skipna here because I need to know dtype to make a good default choice.
1079+
# We cannot handle this easily for xarray Datasets in xarray_reduce
1080+
if skipna and func in ["all", "any", "count"]:
1081+
raise ValueError(f"skipna cannot be truthy for {func} reductions.")
1082+
1083+
if skipna or (skipna is None and array.dtype.kind in "cfO"):
1084+
if "nan" not in func and func not in ["all", "any", "count"]:
1085+
func = f"nan{func}"
1086+
1087+
if min_count is not None and min_count > 1:
1088+
if func not in ["nansum", "nanprod"]:
1089+
raise ValueError(
1090+
"min_count can be > 1 only for nansum, nanprod."
1091+
" or for sum, prod with skipna=True."
1092+
" This is an Xarray limitation."
1093+
)
1094+
elif "nan" not in func and skipna:
1095+
func = f"nan{func}"
10691096

10701097
if axis is None:
10711098
axis = tuple(array.ndim + np.arange(-by.ndim, 0))

dask_groupby/xarray.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def xarray_reduce(
6060
method: str = "mapreduce",
6161
backend: str = "numpy",
6262
keep_attrs: bool = True,
63-
skipna: bool = True,
63+
skipna: Optional[bool] = None,
6464
min_count: Optional[int] = None,
6565
**finalize_kwargs,
6666
):
@@ -113,10 +113,15 @@ def xarray_reduce(
113113
keep_attrs: bool, optional
114114
Preserve attrs?
115115
skipna: bool, optional
116-
Use NaN-skipping aggregations like nanmean?
117-
min_count: int, optional
118-
NaN out when number of non-NaN values in aggregation is < min_count
119-
Only applies to nansum, nanprod.
116+
If True, skip missing values (as marked by NaN). By default, only
117+
skips missing values for float dtypes; other dtypes either do not
118+
have a sentinel missing value (int) or ``skipna=True`` has not been
119+
implemented (object, datetime64 or timedelta64).
120+
min_count : int, default: None
121+
The required number of valid values to perform the operation. If
122+
fewer than min_count non-NA values are present the result will be
123+
NA. Only used if skipna is set to True or defaults to True for the
124+
array's dtype.
120125
finalize_kwargs: dict, optional
121126
kwargs passed to the finalize function, like ddof for var, std.
122127
@@ -130,9 +135,6 @@ def xarray_reduce(
130135
FIXME: Add docs.
131136
"""
132137

133-
if (skipna or min_count is not None) and func not in ["all", "any", "count"]:
134-
func = f"nan{func}"
135-
136138
for b in by:
137139
if isinstance(b, xr.DataArray) and b.name is None:
138140
raise ValueError("Cannot group by unnamed DataArrays.")
@@ -285,6 +287,7 @@ def wrapper(*args, **kwargs):
285287
"fill_value": fill_value,
286288
"method": method,
287289
"min_count": min_count,
290+
"skipna": skipna,
288291
"backend": backend,
289292
# The following mess exists because for multiple `by`s I factorize eagerly
290293
# here before passing it on; this means I have to handle the

0 commit comments

Comments
 (0)