Skip to content

Commit e87e030

Browse files
Added tests for EAs
Co-authored-by: André Correia <[email protected]>
1 parent c692076 commit e87e030

File tree

4 files changed

+126
-5
lines changed

4 files changed

+126
-5
lines changed

pandas/_libs/groupby.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,9 @@ def group_sum(
735735
for j in range(K):
736736
val = values[i, j]
737737

738+
if _treat_as_na(sumx[lab, j], is_datetimelike):
739+
continue
740+
738741
if uses_mask:
739742
isna_entry = mask[i, j]
740743
else:
@@ -1107,7 +1110,7 @@ def group_mean(
11071110
isna_entry = _treat_as_na(val, is_datetimelike)
11081111

11091112
if not skipna and isna_entry:
1110-
sumx[lab, j] = nan_val
1113+
sumx[lab, j] = val
11111114
nobs[lab, j] = 0
11121115
continue
11131116

pandas/core/arrays/arrow/array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2215,7 +2215,9 @@ def _replace_with_mask(
22152215
def _to_masked(self):
22162216
pa_dtype = self._pa_array.type
22172217

2218-
if pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype):
2218+
if pa.types.is_floating(pa_dtype):
2219+
na_value = np.nan
2220+
elif pa.types.is_integer(pa_dtype):
22192221
na_value = 1
22202222
elif pa.types.is_boolean(pa_dtype):
22212223
na_value = True

pandas/core/groupby/ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,13 @@ def _call_cython_op(
392392
values[mask] = True
393393
values = values.astype(bool, copy=False).view(np.int8)
394394
is_numeric = True
395+
elif (
396+
self.how in ["median", "sem", "std", "var"]
397+
and "skipna" in kwargs
398+
and not kwargs["skipna"]
399+
):
400+
# if skipna=False we don't want to use masks created for Nullable dtypes
401+
mask = None
395402

396403
values = values.T
397404
if mask is not None:
@@ -1257,4 +1264,4 @@ def _get_splitter(
12571264
# i.e. DataFrame
12581265
klass = FrameSplitter
12591266

1260-
return klass(data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids)
1267+
return klass(data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids)

pandas/tests/groupby/test_reductions.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from string import ascii_lowercase
44

55
import numpy as np
6+
import pyarrow as pa
67
import pytest
78

89
from pandas._libs.tslibs import iNaT
@@ -1052,7 +1053,31 @@ def scipy_sem(*args, **kwargs):
10521053
[
10531054
("sum", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
10541055
("sum", ["foo", "bar", "baz", "foo", pd.NA, "foo"]),
1056+
(
1057+
"sum",
1058+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1059+
),
1060+
(
1061+
"sum",
1062+
Series(
1063+
pd.array(
1064+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1065+
)
1066+
),
1067+
),
10551068
("min", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1069+
(
1070+
"min",
1071+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1072+
),
1073+
(
1074+
"min",
1075+
Series(
1076+
pd.array(
1077+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1078+
)
1079+
),
1080+
),
10561081
(
10571082
"min",
10581083
[
@@ -1076,6 +1101,18 @@ def scipy_sem(*args, **kwargs):
10761101
],
10771102
),
10781103
("max", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1104+
(
1105+
"max",
1106+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1107+
),
1108+
(
1109+
"max",
1110+
Series(
1111+
pd.array(
1112+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1113+
)
1114+
),
1115+
),
10791116
(
10801117
"max",
10811118
[
@@ -1099,6 +1136,18 @@ def scipy_sem(*args, **kwargs):
10991136
],
11001137
),
11011138
("mean", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1139+
(
1140+
"mean",
1141+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1142+
),
1143+
(
1144+
"mean",
1145+
Series(
1146+
pd.array(
1147+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1148+
)
1149+
),
1150+
),
11021151
(
11031152
"mean",
11041153
[
@@ -1122,6 +1171,18 @@ def scipy_sem(*args, **kwargs):
11221171
],
11231172
),
11241173
("median", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1174+
(
1175+
"median",
1176+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1177+
),
1178+
(
1179+
"median",
1180+
Series(
1181+
pd.array(
1182+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1183+
)
1184+
),
1185+
),
11251186
(
11261187
"median",
11271188
[
@@ -1145,9 +1206,57 @@ def scipy_sem(*args, **kwargs):
11451206
],
11461207
),
11471208
("prod", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1209+
(
1210+
"prod",
1211+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1212+
),
1213+
(
1214+
"prod",
1215+
Series(
1216+
pd.array(
1217+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1218+
)
1219+
),
1220+
),
11481221
("sem", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1222+
(
1223+
"sem",
1224+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1225+
),
1226+
(
1227+
"sem",
1228+
Series(
1229+
pd.array(
1230+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1231+
)
1232+
),
1233+
),
11491234
("std", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1235+
(
1236+
"std",
1237+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1238+
),
1239+
(
1240+
"std",
1241+
Series(
1242+
pd.array(
1243+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1244+
)
1245+
),
1246+
),
11501247
("var", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1248+
(
1249+
"var",
1250+
Series(pd.array([-1.0, 1.2, -1.1, 1.5, np.nan, 1.0], dtype="Float64")),
1251+
),
1252+
(
1253+
"var",
1254+
Series(
1255+
pd.array(
1256+
[1.0, 2.0, 3.0, np.nan, 4.0, 5.0], dtype=pd.ArrowDtype(pa.float64())
1257+
)
1258+
),
1259+
),
11511260
("any", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
11521261
("all", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
11531262
("skew", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
@@ -1163,7 +1272,7 @@ def test_skipna_reduction_ops_cython(reduction_method, values):
11631272
expected = gb.apply(
11641273
lambda x: getattr(x, reduction_method)(skipna=False), include_groups=False
11651274
)
1166-
tm.assert_frame_equal(result_cython, expected, check_exact=False)
1275+
tm.assert_frame_equal(result_cython, expected, check_exact=False, check_dtype=False)
11671276

11681277

11691278
@pytest.mark.parametrize(
@@ -1310,4 +1419,4 @@ def test_groupby_std_datetimelike():
13101419
td4 = pd.Timedelta("2886 days 00:42:34.664668096")
13111420
exp_ser = Series([td1 * 2, td1, td1, td1, td4], index=np.arange(5))
13121421
expected = DataFrame({"A": exp_ser, "B": exp_ser, "C": exp_ser})
1313-
tm.assert_frame_equal(result, expected)
1422+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)