@@ -413,6 +413,7 @@ def chunk_reduce(
413
413
reindex : bool = False ,
414
414
isbin : bool = False ,
415
415
backend : str = "numpy" ,
416
+ kwargs = None ,
416
417
) -> IntermediateDict :
417
418
"""
418
419
Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -458,6 +459,9 @@ def chunk_reduce(
458
459
if not isinstance (fill_value , Sequence ):
459
460
fill_value = (fill_value ,)
460
461
462
+ if kwargs is None :
463
+ kwargs = ({},) * len (func )
464
+
461
465
# when axis is a tuple
462
466
# collapse and move reduction dimensions to the end
463
467
if isinstance (axis , Sequence ) and len (axis ) < by .ndim :
@@ -503,7 +507,7 @@ def chunk_reduce(
503
507
final_array_shape += results ["groups" ].shape
504
508
final_groups_shape += results ["groups" ].shape
505
509
506
- for reduction , fv in zip (func , fill_value ):
510
+ for reduction , fv , kw in zip (func , fill_value , kwargs ):
507
511
if empty :
508
512
result = np .full (shape = final_array_shape , fill_value = fv )
509
513
else :
@@ -516,6 +520,7 @@ def chunk_reduce(
516
520
size = size ,
517
521
# important when reducing with "offset" groups
518
522
fill_value = fv ,
523
+ ** kw ,
519
524
)
520
525
else :
521
526
result = _get_aggregate (backend )(
@@ -527,6 +532,7 @@ def chunk_reduce(
527
532
# important when reducing with "offset" groups
528
533
fill_value = fv ,
529
534
dtype = np .intp if reduction == "nanlen" else dtype ,
535
+ ** kw ,
530
536
)
531
537
if np .any (~ mask ):
532
538
# remove NaN group label which should be last
@@ -573,6 +579,7 @@ def _finalize_results(
573
579
expected_groups : Union [Sequence , np .ndarray , None ],
574
580
fill_value : Any ,
575
581
min_count : Optional [int ] = None ,
582
+ finalize_kwargs : Optional [Mapping ] = None ,
576
583
):
577
584
"""Finalize results by
578
585
1. Squeezing out dummy dimensions
@@ -595,10 +602,11 @@ def _finalize_results(
595
602
if fill_value is not None :
596
603
counts = squeezed ["intermediates" ][- 1 ]
597
604
squeezed ["intermediates" ] = squeezed ["intermediates" ][:- 1 ]
598
-
599
605
if min_count is None :
600
606
min_count = 1
601
- result [agg .name ] = agg .finalize (* squeezed ["intermediates" ])
607
+ if finalize_kwargs is None :
608
+ finalize_kwargs = {}
609
+ result [agg .name ] = agg .finalize (* squeezed ["intermediates" ], ** finalize_kwargs )
602
610
result [agg .name ] = np .where (counts >= min_count , result [agg .name ], fill_value )
603
611
604
612
# Final reindexing has to be here to be lazy
@@ -621,10 +629,13 @@ def _npg_aggregate(
621
629
fill_value : Any = None ,
622
630
min_count : Optional [int ] = None ,
623
631
backend : str = "numpy" ,
632
+ finalize_kwargs : Optional [Mapping ] = None ,
624
633
) -> FinalResultsDict :
625
634
"""Final aggregation step of tree reduction"""
626
635
results = _npg_combine (x_chunk , agg , axis , keepdims , group_ndim , backend )
627
- return _finalize_results (results , agg , axis , expected_groups , fill_value , min_count )
636
+ return _finalize_results (
637
+ results , agg , axis , expected_groups , fill_value , min_count , finalize_kwargs
638
+ )
628
639
629
640
630
641
def _npg_combine (
@@ -782,6 +793,7 @@ def groupby_agg(
782
793
min_count : Optional [int ] = None ,
783
794
isbin : bool = False ,
784
795
backend : str = "numpy" ,
796
+ finalize_kwargs : Optional [Mapping ] = None ,
785
797
) -> Tuple ["DaskArray" , Union [np .ndarray , "DaskArray" ]]:
786
798
787
799
import dask .array
@@ -851,6 +863,14 @@ def groupby_agg(
851
863
group_chunks = (len (expected_groups ),) if expected_groups is not None else (np .nan ,)
852
864
expected_agg = expected_groups
853
865
866
+ agg_kwargs = dict (
867
+ group_ndim = by .ndim ,
868
+ fill_value = fill_value ,
869
+ min_count = min_count ,
870
+ backend = backend ,
871
+ finalize_kwargs = finalize_kwargs ,
872
+ )
873
+
854
874
if method == "mapreduce" :
855
875
# reduced is really a dict mapping reduction name to array
856
876
# and "groups" to an array of group labels
@@ -862,10 +882,7 @@ def groupby_agg(
862
882
_npg_aggregate ,
863
883
agg = agg ,
864
884
expected_groups = expected_agg ,
865
- group_ndim = by .ndim ,
866
- fill_value = fill_value ,
867
- min_count = min_count ,
868
- backend = backend ,
885
+ ** agg_kwargs ,
869
886
),
870
887
combine = partial (_npg_combine , agg = agg , group_ndim = by .ndim , backend = backend ),
871
888
name = f"{ name } -reduce" ,
@@ -892,10 +909,7 @@ def groupby_agg(
892
909
_npg_aggregate ,
893
910
agg = agg ,
894
911
expected_groups = None ,
895
- group_ndim = by .ndim ,
896
- fill_value = fill_value ,
897
- min_count = min_count ,
898
- backend = backend ,
912
+ ** agg_kwargs ,
899
913
axis = axis ,
900
914
keepdims = True ,
901
915
),
@@ -982,6 +996,7 @@ def groupby_reduce(
982
996
split_out : int = 1 ,
983
997
method : str = "mapreduce" ,
984
998
backend : str = "numpy" ,
999
+ finalize_kwargs : Optional [Mapping ] = None ,
985
1000
) -> Tuple ["DaskArray" , Union [np .ndarray , "DaskArray" ]]:
986
1001
"""
987
1002
GroupBy reductions using tree reductions for dask.array
@@ -1026,6 +1041,8 @@ def groupby_reduce(
1026
1041
chunking ``array`` for this method by first rechunking using ``rechunk_for_cohorts``.
1027
1042
backend: {"numpy", "numba"}, optional
1028
1043
Backend for numpy_groupies. numpy by default.
1044
+ finalize_kwargs: Mapping, optional
1045
+ Kwargs passed to finalize the reduction such as ddof for var, std.
1029
1046
1030
1047
Returns
1031
1048
-------
@@ -1112,18 +1129,25 @@ def groupby_reduce(
1112
1129
reduction .finalize = None
1113
1130
# xarray's count is npg's nanlen
1114
1131
func = reduction .name if reduction .name != "count" else "nanlen"
1115
- if min_count is not None :
1132
+ if finalize_kwargs is None :
1133
+ finalize_kwargs = {}
1134
+ if isinstance (finalize_kwargs , Mapping ):
1135
+ finalize_kwargs = (finalize_kwargs ,)
1136
+ append_nanlen = min_count is not None or reduction .name in ["nanvar" , "nanstd" ]
1137
+ if append_nanlen :
1116
1138
func = (func , "nanlen" )
1139
+ finalize_kwargs = finalize_kwargs + ({},)
1117
1140
1118
1141
results = chunk_reduce (
1119
1142
array ,
1120
1143
by ,
1121
1144
func = func ,
1122
1145
axis = axis ,
1123
1146
expected_groups = expected_groups if isbin else None ,
1124
- fill_value = (fill_value , 0 ) if min_count is not None else fill_value ,
1147
+ fill_value = (fill_value , 0 ) if append_nanlen else fill_value ,
1125
1148
dtype = reduction .dtype ,
1126
1149
isbin = isbin ,
1150
+ kwargs = finalize_kwargs ,
1127
1151
) # type: ignore
1128
1152
1129
1153
if reduction .name in ["argmin" , "argmax" , "nanargmax" , "nanargmin" ]:
@@ -1133,6 +1157,12 @@ def groupby_reduce(
1133
1157
results ["intermediates" ][0 ] = np .unravel_index (
1134
1158
results ["intermediates" ][0 ], array .shape
1135
1159
)[- 1 ]
1160
+ elif reduction .name in ["nanvar" , "nanstd" ]:
1161
+ # Fix npg bug where all-NaN rows are 0 instead of NaN
1162
+ value , counts = results ["intermediates" ]
1163
+ mask = counts <= 0
1164
+ value [mask ] = np .nan
1165
+ results ["intermediates" ] = (value ,)
1136
1166
1137
1167
if isbin :
1138
1168
expected_groups = np .arange (len (expected_groups ) - 1 )
@@ -1167,6 +1197,7 @@ def groupby_reduce(
1167
1197
min_count = min_count ,
1168
1198
isbin = isbin ,
1169
1199
backend = backend ,
1200
+ finalize_kwargs = finalize_kwargs ,
1170
1201
)
1171
1202
if method == "cohorts" :
1172
1203
assert len (axis ) == 1
0 commit comments