Skip to content

Commit c692076

Browse files
Refactored test parameterization
Co-authored-by: Tiago Firmino <[email protected]>
1 parent 5e3a965 commit c692076

File tree

1 file changed

+96
-37
lines changed

1 file changed

+96
-37
lines changed

pandas/tests/groupby/test_reductions.py

Lines changed: 96 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1048,63 +1048,122 @@ def scipy_sem(*args, **kwargs):
10481048

10491049

10501050
@pytest.mark.parametrize(
1051-
"data",
1051+
"reduction_method, values",
10521052
[
1053-
{
1054-
"l": ["A", "A", "A", "A", "B", "B", "B", "B"],
1055-
"f": [-1.0, 1.2, -1.1, 1.5, -1.1, 1.5, np.nan, 1.0],
1056-
"s": ["foo", "bar", "baz", "foo", "foo", "foo", pd.NA, "foo"],
1057-
"t": [
1053+
("sum", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1054+
("sum", ["foo", "bar", "baz", "foo", pd.NA, "foo"]),
1055+
("min", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1056+
(
1057+
"min",
1058+
[
10581059
Timestamp("2024-01-01"),
10591060
Timestamp("2024-01-02"),
10601061
Timestamp("2024-01-03"),
1061-
Timestamp("2024-01-04"),
1062-
Timestamp("2024-01-06"),
10631062
Timestamp("2024-01-07"),
10641063
Timestamp("2024-01-08"),
10651064
pd.NaT,
10661065
],
1067-
"td": [
1066+
),
1067+
(
1068+
"min",
1069+
[
10681070
pd.Timedelta(days=1),
10691071
pd.Timedelta(days=2),
10701072
pd.Timedelta(days=3),
1071-
pd.Timedelta(days=4),
1072-
pd.Timedelta(days=6),
10731073
pd.Timedelta(days=7),
10741074
pd.Timedelta(days=8),
10751075
pd.NaT,
10761076
],
1077-
}
1078-
],
1079-
)
1080-
@pytest.mark.parametrize(
1081-
"reduction_method,columns",
1082-
[
1083-
("sum", ["f", "s"]),
1084-
("min", ["f", "t", "td"]),
1085-
("max", ["f", "t", "td"]),
1086-
("mean", ["f", "t", "td"]),
1087-
("median", ["f", "t", "td"]),
1088-
("prod", ["f"]),
1089-
("sem", ["f"]),
1090-
("std", ["f"]),
1091-
("var", ["f"]),
1092-
("any", ["f"]),
1093-
("all", ["f"]),
1094-
("skew", ["f"]),
1077+
),
1078+
("max", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1079+
(
1080+
"max",
1081+
[
1082+
Timestamp("2024-01-01"),
1083+
Timestamp("2024-01-02"),
1084+
Timestamp("2024-01-03"),
1085+
Timestamp("2024-01-07"),
1086+
Timestamp("2024-01-08"),
1087+
pd.NaT,
1088+
],
1089+
),
1090+
(
1091+
"max",
1092+
[
1093+
pd.Timedelta(days=1),
1094+
pd.Timedelta(days=2),
1095+
pd.Timedelta(days=3),
1096+
pd.Timedelta(days=7),
1097+
pd.Timedelta(days=8),
1098+
pd.NaT,
1099+
],
1100+
),
1101+
("mean", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1102+
(
1103+
"mean",
1104+
[
1105+
Timestamp("2024-01-01"),
1106+
Timestamp("2024-01-02"),
1107+
Timestamp("2024-01-03"),
1108+
Timestamp("2024-01-07"),
1109+
Timestamp("2024-01-08"),
1110+
pd.NaT,
1111+
],
1112+
),
1113+
(
1114+
"mean",
1115+
[
1116+
pd.Timedelta(days=1),
1117+
pd.Timedelta(days=2),
1118+
pd.Timedelta(days=3),
1119+
pd.Timedelta(days=7),
1120+
pd.Timedelta(days=8),
1121+
pd.NaT,
1122+
],
1123+
),
1124+
("median", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1125+
(
1126+
"median",
1127+
[
1128+
Timestamp("2024-01-01"),
1129+
Timestamp("2024-01-02"),
1130+
Timestamp("2024-01-03"),
1131+
Timestamp("2024-01-07"),
1132+
Timestamp("2024-01-08"),
1133+
pd.NaT,
1134+
],
1135+
),
1136+
(
1137+
"median",
1138+
[
1139+
pd.Timedelta(days=1),
1140+
pd.Timedelta(days=2),
1141+
pd.Timedelta(days=3),
1142+
pd.Timedelta(days=7),
1143+
pd.Timedelta(days=8),
1144+
pd.NaT,
1145+
],
1146+
),
1147+
("prod", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1148+
("sem", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1149+
("std", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1150+
("var", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1151+
("any", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1152+
("all", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
1153+
("skew", [-1.0, 1.2, -1.1, 1.5, np.nan, 1.0]),
10951154
],
10961155
)
1097-
def test_skipna_reduction_ops_cython(reduction_method, columns, data):
1156+
def test_skipna_reduction_ops_cython(reduction_method, values):
10981157
# GH15675
10991158
# Testing the skipna parameter against possible datatypes
1100-
df = DataFrame(data)
1159+
df = DataFrame({"key": [1, 1, 1, 2, 2, 2], "values": values})
1160+
gb = df.groupby("key")
11011161

1102-
for column in columns:
1103-
result_cython = getattr(df.groupby("l")[column], reduction_method)(skipna=False)
1104-
expected = df.groupby("l")[column].apply(
1105-
lambda x: getattr(x, reduction_method)(skipna=False)
1106-
)
1107-
tm.assert_series_equal(result_cython, expected, check_exact=False)
1162+
result_cython = getattr(gb, reduction_method)(skipna=False)
1163+
expected = gb.apply(
1164+
lambda x: getattr(x, reduction_method)(skipna=False), include_groups=False
1165+
)
1166+
tm.assert_frame_equal(result_cython, expected, check_exact=False)
11081167

11091168

11101169
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)