Skip to content

Commit 558bc25

Browse files
committed
WIP EAs support
1 parent c2ceb57 commit 558bc25

File tree

4 files changed: 78 additions and 36 deletions.

pandas/_libs/groupby.pyx

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ cdef enum InterpolationEnumType:
6262
INTERPOLATION_MIDPOINT
6363

6464

65-
cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept nogil:
65+
cdef float64_t median_linear_mask(
66+
float64_t* a,
67+
int n,
68+
uint8_t* mask,
69+
bint skipna=True
70+
) noexcept nogil:
6671
cdef:
6772
int i, j, na_count = 0
6873
float64_t* tmp
@@ -74,6 +79,8 @@ cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept n
7479
# count NAs
7580
for i in range(n):
7681
if mask[i]:
82+
if not skipna:
83+
return NaN
7784
na_count += 1
7885

7986
if na_count:
@@ -235,7 +242,7 @@ def group_median_float64(
235242

236243
for j in range(ngroups):
237244
size = _counts[j + 1]
238-
result = median_linear_mask(ptr, size, ptr_mask)
245+
result = median_linear_mask(ptr, size, ptr_mask, skipna)
239246
out[j, i] = result
240247

241248
if result != result:
@@ -739,6 +746,8 @@ def group_sum(
739746
continue
740747

741748
if uses_mask:
749+
if result_mask[lab, j]:
750+
continue
742751
isna_entry = mask[i, j]
743752
else:
744753
isna_entry = _treat_as_na(val, is_datetimelike)
@@ -747,7 +756,10 @@ def group_sum(
747756
if skipna:
748757
continue
749758
else:
750-
sumx[lab, j] = val
759+
if uses_mask:
760+
result_mask[lab, j] = True
761+
else:
762+
sumx[lab, j] = val
751763
compensation[lab, j] = 0
752764
continue
753765

@@ -824,6 +836,8 @@ def group_prod(
824836
val = values[i, j]
825837

826838
if uses_mask:
839+
if result_mask[lab, j]:
840+
continue
827841
isna_entry = mask[i, j]
828842
else:
829843
isna_entry = _treat_as_na(val, False)
@@ -832,7 +846,10 @@ def group_prod(
832846
nobs[lab, j] += 1
833847
prodx[lab, j] *= val
834848
elif not skipna:
835-
prodx[lab, j] = val
849+
if uses_mask:
850+
result_mask[lab, j] = True
851+
else:
852+
prodx[lab, j] = val
836853
nobs[lab, j] = 0
837854
continue
838855

@@ -891,6 +908,8 @@ def group_var(
891908
val = values[i, j]
892909

893910
if uses_mask:
911+
if result_mask[lab, j]:
912+
continue
894913
isna_entry = mask[i, j]
895914
elif is_datetimelike:
896915
# With group_var, we cannot just use _treat_as_na bc
@@ -901,7 +920,10 @@ def group_var(
901920
isna_entry = _treat_as_na(val, is_datetimelike)
902921

903922
if not skipna and isna_entry:
904-
out[lab, j] = val
923+
if uses_mask:
924+
result_mask[lab, j] = True
925+
else:
926+
out[lab, j] = val
905927
nobs[lab, j] = 0
906928
continue
907929

@@ -1100,6 +1122,8 @@ def group_mean(
11001122
val = values[i, j]
11011123

11021124
if uses_mask:
1125+
if result_mask[lab, j]:
1126+
continue
11031127
isna_entry = mask[i, j]
11041128
elif is_datetimelike:
11051129
# With group_mean, we cannot just use _treat_as_na bc
@@ -1110,7 +1134,10 @@ def group_mean(
11101134
isna_entry = _treat_as_na(val, is_datetimelike)
11111135

11121136
if not skipna and isna_entry:
1113-
sumx[lab, j] = val
1137+
if uses_mask:
1138+
result_mask[lab, j] = True
1139+
else:
1140+
sumx[lab, j] = val
11141141
nobs[lab, j] = 0
11151142
continue
11161143

@@ -1762,12 +1789,17 @@ cdef group_min_max(
17621789
val = values[i, j]
17631790

17641791
if uses_mask:
1792+
if result_mask[lab, j]:
1793+
continue
17651794
isna_entry = mask[i, j]
17661795
else:
17671796
isna_entry = _treat_as_na(val, is_datetimelike)
17681797

17691798
if not skipna and isna_entry:
1770-
group_min_or_max[lab, j] = val
1799+
if uses_mask:
1800+
result_mask[lab, j] = True
1801+
else:
1802+
group_min_or_max[lab, j] = val
17711803
nobs[lab, j] = 0
17721804
continue
17731805

pandas/core/arrays/arrow/array.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,12 +2237,6 @@ def _groupby_op(
22372237
ids: npt.NDArray[np.intp],
22382238
**kwargs,
22392239
):
2240-
if how in ["sum", "prod", "mean", "median", "var", "sem", "std", "nim", "max"]:
2241-
if "skipna" in kwargs and not kwargs["skipna"]:
2242-
raise NotImplementedError(
2243-
f"method '{how}' with skipna=False not implemented for Arrow dtypes"
2244-
)
2245-
22462240
if isinstance(self.dtype, StringDtype):
22472241
return super()._groupby_op(
22482242
how=how,
@@ -2308,33 +2302,31 @@ def _str_contains(
23082302
def _str_startswith(self, pat: str | tuple[str, ...], na=None) -> Self:
23092303
if isinstance(pat, str):
23102304
result = pc.starts_with(self._pa_array, pattern=pat)
2305+
elif len(pat) == 0:
2306+
# For empty tuple, pd.StringDtype() returns null for missing values
2307+
# and false for valid values.
2308+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
23112309
else:
2312-
if len(pat) == 0:
2313-
# For empty tuple, pd.StringDtype() returns null for missing values
2314-
# and false for valid values.
2315-
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2316-
else:
2317-
result = pc.starts_with(self._pa_array, pattern=pat[0])
2310+
result = pc.starts_with(self._pa_array, pattern=pat[0])
23182311

2319-
for p in pat[1:]:
2320-
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
2312+
for p in pat[1:]:
2313+
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
23212314
if not isna(na):
23222315
result = result.fill_null(na)
23232316
return type(self)(result)
23242317

23252318
def _str_endswith(self, pat: str | tuple[str, ...], na=None) -> Self:
23262319
if isinstance(pat, str):
23272320
result = pc.ends_with(self._pa_array, pattern=pat)
2321+
elif len(pat) == 0:
2322+
# For empty tuple, pd.StringDtype() returns null for missing values
2323+
# and false for valid values.
2324+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
23282325
else:
2329-
if len(pat) == 0:
2330-
# For empty tuple, pd.StringDtype() returns null for missing values
2331-
# and false for valid values.
2332-
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2333-
else:
2334-
result = pc.ends_with(self._pa_array, pattern=pat[0])
2326+
result = pc.ends_with(self._pa_array, pattern=pat[0])
23352327

2336-
for p in pat[1:]:
2337-
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
2328+
for p in pat[1:]:
2329+
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
23382330
if not isna(na):
23392331
result = result.fill_null(na)
23402332
return type(self)(result)

pandas/core/groupby/ops.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,6 @@ def _call_cython_op(
392392
values[mask] = True
393393
values = values.astype(bool, copy=False).view(np.int8)
394394
is_numeric = True
395-
elif (
396-
self.how in ["median", "sem", "std", "var"]
397-
and "skipna" in kwargs
398-
and not kwargs["skipna"]
399-
):
400-
# if skipna=False we don't want to use masks created for Nullable dtypes
401-
mask = None
402395

403396
values = values.T
404397
if mask is not None:

pandas/tests/extension/base/reduce.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
7777

7878
tm.assert_extension_array_equal(result1, expected)
7979

80+
def check_reduce_groupby(self, ser: pd.Series, op_name: str, skipna: bool):
81+
# Check that groupby reduction behaves correctly
82+
df = pd.DataFrame({"a": ser, "key": [1, 2] * (len(ser) // 2)})
83+
grp = df.groupby("key")["a"]
84+
res_op = getattr(grp, op_name)
85+
86+
expected = grp.apply(
87+
lambda x: getattr(x.astype("float64"), op_name)(skipna=skipna)
88+
)
89+
90+
result = res_op(skipna=skipna)
91+
tm.assert_series_equal(result, expected)
92+
8093
@pytest.mark.parametrize("skipna", [True, False])
8194
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna):
8295
op_name = all_boolean_reductions
@@ -129,3 +142,15 @@ def test_reduce_frame(self, data, all_numeric_reductions, skipna):
129142
pytest.skip(f"Reduction {op_name} not supported for this dtype")
130143

131144
self.check_reduce_frame(ser, op_name, skipna)
145+
146+
@pytest.mark.parametrize("skipna", [True, False])
147+
def test_reduce_groupby_numeric(self, data, all_numeric_reductions, skipna):
148+
op_name = all_numeric_reductions
149+
ser = pd.Series(data)
150+
if not is_numeric_dtype(ser.dtype):
151+
pytest.skip(f"{ser.dtype} is not numeric dtype")
152+
153+
if not self._supports_reduction(ser, op_name):
154+
pytest.skip(f"Reduction {op_name} not supported for this dtype")
155+
156+
self.check_reduce_groupby(ser, op_name, skipna)

0 commit comments

Comments
 (0)