Skip to content

Commit 95701d4

Browse files
committed
Support reduction-specific kwargs in finalize
E.g. ddof for var, std
1 parent: ec44bb5 · commit: 95701d4

File tree

4 files changed

+73
-26
lines changed

4 files changed

+73
-26
lines changed

dask_groupby/aggregations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def nansum_of_squares(group_idx, array, size=None, fill_value=None):
147147
# TODO: fix this for complex numbers
148148
def _var_finalize(sumsq, sum_, count, ddof=0):
149149
result = (sumsq - (sum_ ** 2 / count)) / (count - ddof)
150-
result[(count - ddof) <= 0] = np.nan
150+
result[count <= ddof] = np.nan
151151
return result
152152

153153

dask_groupby/core.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ def chunk_reduce(
413413
reindex: bool = False,
414414
isbin: bool = False,
415415
backend: str = "numpy",
416+
kwargs=None,
416417
) -> IntermediateDict:
417418
"""
418419
Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -458,6 +459,9 @@ def chunk_reduce(
458459
if not isinstance(fill_value, Sequence):
459460
fill_value = (fill_value,)
460461

462+
if kwargs is None:
463+
kwargs = ({},) * len(func)
464+
461465
# when axis is a tuple
462466
# collapse and move reduction dimensions to the end
463467
if isinstance(axis, Sequence) and len(axis) < by.ndim:
@@ -503,7 +507,7 @@ def chunk_reduce(
503507
final_array_shape += results["groups"].shape
504508
final_groups_shape += results["groups"].shape
505509

506-
for reduction, fv in zip(func, fill_value):
510+
for reduction, fv, kw in zip(func, fill_value, kwargs):
507511
if empty:
508512
result = np.full(shape=final_array_shape, fill_value=fv)
509513
else:
@@ -516,6 +520,7 @@ def chunk_reduce(
516520
size=size,
517521
# important when reducing with "offset" groups
518522
fill_value=fv,
523+
**kw,
519524
)
520525
else:
521526
result = _get_aggregate(backend)(
@@ -527,6 +532,7 @@ def chunk_reduce(
527532
# important when reducing with "offset" groups
528533
fill_value=fv,
529534
dtype=np.intp if reduction == "nanlen" else dtype,
535+
**kw,
530536
)
531537
if np.any(~mask):
532538
# remove NaN group label which should be last
@@ -573,6 +579,7 @@ def _finalize_results(
573579
expected_groups: Union[Sequence, np.ndarray, None],
574580
fill_value: Any,
575581
min_count: Optional[int] = None,
582+
finalize_kwargs: Optional[Mapping] = None,
576583
):
577584
"""Finalize results by
578585
1. Squeezing out dummy dimensions
@@ -595,10 +602,11 @@ def _finalize_results(
595602
if fill_value is not None:
596603
counts = squeezed["intermediates"][-1]
597604
squeezed["intermediates"] = squeezed["intermediates"][:-1]
598-
599605
if min_count is None:
600606
min_count = 1
601-
result[agg.name] = agg.finalize(*squeezed["intermediates"])
607+
if finalize_kwargs is None:
608+
finalize_kwargs = {}
609+
result[agg.name] = agg.finalize(*squeezed["intermediates"], **finalize_kwargs)
602610
result[agg.name] = np.where(counts >= min_count, result[agg.name], fill_value)
603611

604612
# Final reindexing has to be here to be lazy
@@ -621,10 +629,13 @@ def _npg_aggregate(
621629
fill_value: Any = None,
622630
min_count: Optional[int] = None,
623631
backend: str = "numpy",
632+
finalize_kwargs: Optional[Mapping] = None,
624633
) -> FinalResultsDict:
625634
"""Final aggregation step of tree reduction"""
626635
results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, backend)
627-
return _finalize_results(results, agg, axis, expected_groups, fill_value, min_count)
636+
return _finalize_results(
637+
results, agg, axis, expected_groups, fill_value, min_count, finalize_kwargs
638+
)
628639

629640

630641
def _npg_combine(
@@ -782,6 +793,7 @@ def groupby_agg(
782793
min_count: Optional[int] = None,
783794
isbin: bool = False,
784795
backend: str = "numpy",
796+
finalize_kwargs: Optional[Mapping] = None,
785797
) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
786798

787799
import dask.array
@@ -851,6 +863,14 @@ def groupby_agg(
851863
group_chunks = (len(expected_groups),) if expected_groups is not None else (np.nan,)
852864
expected_agg = expected_groups
853865

866+
agg_kwargs = dict(
867+
group_ndim=by.ndim,
868+
fill_value=fill_value,
869+
min_count=min_count,
870+
backend=backend,
871+
finalize_kwargs=finalize_kwargs,
872+
)
873+
854874
if method == "mapreduce":
855875
# reduced is really a dict mapping reduction name to array
856876
# and "groups" to an array of group labels
@@ -862,10 +882,7 @@ def groupby_agg(
862882
_npg_aggregate,
863883
agg=agg,
864884
expected_groups=expected_agg,
865-
group_ndim=by.ndim,
866-
fill_value=fill_value,
867-
min_count=min_count,
868-
backend=backend,
885+
**agg_kwargs,
869886
),
870887
combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, backend=backend),
871888
name=f"{name}-reduce",
@@ -892,10 +909,7 @@ def groupby_agg(
892909
_npg_aggregate,
893910
agg=agg,
894911
expected_groups=None,
895-
group_ndim=by.ndim,
896-
fill_value=fill_value,
897-
min_count=min_count,
898-
backend=backend,
912+
**agg_kwargs,
899913
axis=axis,
900914
keepdims=True,
901915
),
@@ -982,6 +996,7 @@ def groupby_reduce(
982996
split_out: int = 1,
983997
method: str = "mapreduce",
984998
backend: str = "numpy",
999+
finalize_kwargs: Optional[Mapping] = None,
9851000
) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
9861001
"""
9871002
GroupBy reductions using tree reductions for dask.array
@@ -1026,6 +1041,8 @@ def groupby_reduce(
10261041
chunking ``array`` for this method by first rechunking using ``rechunk_for_cohorts``.
10271042
backend: {"numpy", "numba"}, optional
10281043
Backend for numpy_groupies. numpy by default.
1044+
finalize_kwargs: Mapping, optional
1045+
Kwargs passed to finalize the reduction such as ddof for var, std.
10291046
10301047
Returns
10311048
-------
@@ -1112,18 +1129,25 @@ def groupby_reduce(
11121129
reduction.finalize = None
11131130
# xarray's count is npg's nanlen
11141131
func = reduction.name if reduction.name != "count" else "nanlen"
1115-
if min_count is not None:
1132+
if finalize_kwargs is None:
1133+
finalize_kwargs = {}
1134+
if isinstance(finalize_kwargs, Mapping):
1135+
finalize_kwargs = (finalize_kwargs,)
1136+
append_nanlen = min_count is not None or reduction.name in ["nanvar", "nanstd"]
1137+
if append_nanlen:
11161138
func = (func, "nanlen")
1139+
finalize_kwargs = finalize_kwargs + ({},)
11171140

11181141
results = chunk_reduce(
11191142
array,
11201143
by,
11211144
func=func,
11221145
axis=axis,
11231146
expected_groups=expected_groups if isbin else None,
1124-
fill_value=(fill_value, 0) if min_count is not None else fill_value,
1147+
fill_value=(fill_value, 0) if append_nanlen else fill_value,
11251148
dtype=reduction.dtype,
11261149
isbin=isbin,
1150+
kwargs=finalize_kwargs,
11271151
) # type: ignore
11281152

11291153
if reduction.name in ["argmin", "argmax", "nanargmax", "nanargmin"]:
@@ -1133,6 +1157,12 @@ def groupby_reduce(
11331157
results["intermediates"][0] = np.unravel_index(
11341158
results["intermediates"][0], array.shape
11351159
)[-1]
1160+
elif reduction.name in ["nanvar", "nanstd"]:
1161+
# Fix npg bug where all-NaN rows are 0 instead of NaN
1162+
value, counts = results["intermediates"]
1163+
mask = counts <= 0
1164+
value[mask] = np.nan
1165+
results["intermediates"] = (value,)
11361166

11371167
if isbin:
11381168
expected_groups = np.arange(len(expected_groups) - 1)
@@ -1167,6 +1197,7 @@ def groupby_reduce(
11671197
min_count=min_count,
11681198
isbin=isbin,
11691199
backend=backend,
1200+
finalize_kwargs=finalize_kwargs,
11701201
)
11711202
if method == "cohorts":
11721203
assert len(axis) == 1

dask_groupby/xarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def xarray_reduce(
6262
keep_attrs: bool = True,
6363
skipna: bool = True,
6464
min_count: Optional[int] = None,
65+
**finalize_kwargs,
6566
):
6667
"""GroupBy reduce operations on xarray objects using numpy-groupies
6768
@@ -116,6 +117,8 @@ def xarray_reduce(
116117
min_count: int, optional
117118
NaN out when number of non-NaN values in aggregation is < min_count
118119
Only applies to nansum, nanprod.
120+
finalize_kwargs: dict, optional
121+
kwargs passed to the finalize function, like ddof for var, std.
119122
120123
Raises
121124
------
@@ -291,6 +294,7 @@ def wrapper(*args, **kwargs):
291294
# from "by" so we need the isbin part of the condition
292295
"expected_groups": expected_groups[0] if len(by) == 1 and isbin[0] else None,
293296
"isbin": isbin[0] if len(by) == 1 else False,
297+
"finalize_kwargs": finalize_kwargs,
294298
},
295299
)
296300

tests/test_core.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def test_groupby_reduce(
114114
pytest.param("nanargmin", marks=(pytest.mark.xfail,)),
115115
"any",
116116
"all",
117+
pytest.param("median", marks=(pytest.mark.skip,)),
118+
pytest.param("nanmedian", marks=(pytest.mark.skip,)),
117119
),
118120
)
119121
def test_groupby_reduce_all(size, func, backend):
@@ -128,23 +130,33 @@ def test_groupby_reduce_all(size, func, backend):
128130
if func in ["any", "all"]:
129131
array = array > 0.5
130132

131-
with np.errstate(invalid="ignore", divide="ignore"):
132-
expected = getattr(np, func)(array, axis=-1)
133-
expected = np.expand_dims(expected, -1)
133+
finalize_kwargs = tuple({})
134+
if "var" in func or "std" in func:
135+
finalize_kwargs = finalize_kwargs + ({"ddof": 1}, {"ddof": 0})
134136

135-
actual, _ = groupby_reduce(array, by, func=func, backend=backend)
136-
if "arg" in func:
137-
assert actual.dtype.kind == "i"
138-
assert_equal(actual, expected)
137+
for kwargs in finalize_kwargs:
138+
with np.errstate(invalid="ignore", divide="ignore"):
139+
expected = getattr(np, func)(array, axis=-1, **kwargs)
140+
expected = np.expand_dims(expected, -1)
139141

140-
for method in ["mapreduce", "cohorts"]:
141-
actual, _ = groupby_reduce(
142-
da.from_array(array, chunks=3), by, func=func, method=method, backend=backend
143-
)
142+
actual, _ = groupby_reduce(array, by, func=func, backend=backend, finalize_kwargs=kwargs)
144143
if "arg" in func:
145144
assert actual.dtype.kind == "i"
146145
assert_equal(actual, expected)
147146

147+
for method in ["mapreduce", "cohorts"]:
148+
actual, _ = groupby_reduce(
149+
da.from_array(array, chunks=3),
150+
by,
151+
func=func,
152+
method=method,
153+
backend=backend,
154+
finalize_kwargs=kwargs,
155+
)
156+
if "arg" in func:
157+
assert actual.dtype.kind == "i"
158+
assert_equal(actual, expected)
159+
148160

149161
@pytest.mark.parametrize("size", ((12,), (12, 5)))
150162
@pytest.mark.parametrize("func", ("argmax", "nanargmax", "argmin", "nanargmin"))

0 commit comments

Comments (0)