Skip to content

Commit 8ab2710

Browse files
committed
Merge remote-tracking branch 'upstream/main' into typevar-default
2 parents e415282 + f745945 commit 8ab2710

19 files changed

+377
-146
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
hooks:
1212
- id: isort
1313
- repo: https://github.com/astral-sh/ruff-pre-commit
14-
rev: v0.11.5
14+
rev: v0.11.13
1515
hooks:
1616
- id: ruff
1717
args: [

pandas-stubs/_typing.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,7 @@ SeriesDType: TypeAlias = (
847847
S1 = TypeVar("S1", bound=SeriesDType, default=Any)
848848
# Like S1, but without `default=Any`.
849849
S2 = TypeVar("S2", bound=SeriesDType)
850+
S3 = TypeVar("S3", bound=SeriesDType)
850851

851852
IndexingInt: TypeAlias = (
852853
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
@@ -995,7 +996,12 @@ TimeZones: TypeAlias = str | tzinfo | None | int
995996

996997
# Evaluates to a DataFrame column in DataFrame.assign context.
997998
IntoColumn: TypeAlias = (
998-
AnyArrayLike | Scalar | Callable[[DataFrame], AnyArrayLike | Scalar] | None
999+
AnyArrayLike
1000+
| Scalar
1001+
| Callable[[DataFrame], AnyArrayLike | Scalar | list[Scalar] | range]
1002+
| list[Scalar]
1003+
| range
1004+
| None
9991005
)
10001006

10011007
DatetimeLike: TypeAlias = datetime.datetime | np.datetime64 | Timestamp

pandas-stubs/core/frame.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
22762276
replace: _bool = ...,
22772277
weights: _str | ListLike | None = ...,
22782278
random_state: RandomState | None = ...,
2279-
axis: AxisIndex | None = ...,
2279+
axis: Axis | None = ...,
22802280
ignore_index: _bool = ...,
22812281
) -> Self: ...
22822282
def sem(

pandas-stubs/core/groupby/base.pyi

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
11
from collections.abc import Hashable
22
import dataclasses
3+
from typing import (
4+
Literal,
5+
TypeAlias,
6+
)
37

48
@dataclasses.dataclass(order=True, frozen=True)
59
class OutputKey:
610
label: Hashable
711
position: int
12+
13+
ReductionKernelType: TypeAlias = Literal[
14+
"all",
15+
"any",
16+
"corrwith",
17+
"count",
18+
"first",
19+
"idxmax",
20+
"idxmin",
21+
"last",
22+
"max",
23+
"mean",
24+
"median",
25+
"min",
26+
"nunique",
27+
"prod",
28+
# as long as `quantile`'s signature accepts only
29+
# a single quantile value, it's a reduction.
30+
# GH#27526 might change that.
31+
"quantile",
32+
"sem",
33+
"size",
34+
"skew",
35+
"std",
36+
"sum",
37+
"var",
38+
]
39+
40+
TransformationKernelType: TypeAlias = Literal[
41+
"bfill",
42+
"cumcount",
43+
"cummax",
44+
"cummin",
45+
"cumprod",
46+
"cumsum",
47+
"diff",
48+
"ffill",
49+
"fillna",
50+
"ngroup",
51+
"pct_change",
52+
"rank",
53+
"shift",
54+
]
55+
56+
TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType

pandas-stubs/core/groupby/generic.pyi

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from collections.abc import (
77
)
88
from typing import (
99
Any,
10+
Concatenate,
1011
Generic,
1112
Literal,
1213
NamedTuple,
@@ -18,6 +19,7 @@ from typing import (
1819
from matplotlib.axes import Axes as PlotAxes
1920
import numpy as np
2021
from pandas.core.frame import DataFrame
22+
from pandas.core.groupby.base import TransformReductionListType
2123
from pandas.core.groupby.groupby import (
2224
GroupBy,
2325
GroupByPlot,
@@ -31,6 +33,7 @@ from typing_extensions import (
3133
from pandas._libs.tslibs.timestamps import Timestamp
3234
from pandas._typing import (
3335
S2,
36+
S3,
3437
AggFuncTypeBase,
3538
AggFuncTypeFrame,
3639
ByT,
@@ -40,6 +43,7 @@ from pandas._typing import (
4043
Level,
4144
ListLike,
4245
NsmallestNlargestKeep,
46+
P,
4347
Scalar,
4448
TakeIndexer,
4549
WindowingEngine,
@@ -53,10 +57,30 @@ class NamedAgg(NamedTuple):
5357
aggfunc: AggScalar
5458

5559
class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
60+
@overload
61+
def aggregate(
62+
self,
63+
func: Callable[Concatenate[Series[S2], P], S3],
64+
/,
65+
*args,
66+
engine: WindowingEngine = ...,
67+
engine_kwargs: WindowingEngineKwargs = ...,
68+
**kwargs,
69+
) -> Series[S3]: ...
70+
@overload
71+
def aggregate(
72+
self,
73+
func: Callable[[Series], S3],
74+
*args,
75+
engine: WindowingEngine = ...,
76+
engine_kwargs: WindowingEngineKwargs = ...,
77+
**kwargs,
78+
) -> Series[S3]: ...
5679
@overload
5780
def aggregate(
5881
self,
5982
func: list[AggFuncTypeBase],
83+
/,
6084
*args,
6185
engine: WindowingEngine = ...,
6286
engine_kwargs: WindowingEngineKwargs = ...,
@@ -66,19 +90,33 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
6690
def aggregate(
6791
self,
6892
func: AggFuncTypeBase | None = ...,
93+
/,
6994
*args,
7095
engine: WindowingEngine = ...,
7196
engine_kwargs: WindowingEngineKwargs = ...,
7297
**kwargs,
7398
) -> Series: ...
7499
agg = aggregate
100+
@overload
75101
def transform(
76102
self,
77-
func: Callable | str,
78-
*args,
103+
func: Callable[Concatenate[Series[S2], P], Series[S3]],
104+
/,
105+
*args: Any,
79106
engine: WindowingEngine = ...,
80107
engine_kwargs: WindowingEngineKwargs = ...,
81-
**kwargs,
108+
**kwargs: Any,
109+
) -> Series[S3]: ...
110+
@overload
111+
def transform(
112+
self,
113+
func: Callable,
114+
*args: Any,
115+
**kwargs: Any,
116+
) -> Series: ...
117+
@overload
118+
def transform(
119+
self, func: TransformReductionListType, *args, **kwargs
82120
) -> Series: ...
83121
def filter(
84122
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
@@ -206,13 +244,25 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
206244
**kwargs,
207245
) -> DataFrame: ...
208246
agg = aggregate
247+
@overload
209248
def transform(
210249
self,
211-
func: Callable | str,
212-
*args,
250+
func: Callable[Concatenate[DataFrame, P], DataFrame],
251+
*args: Any,
213252
engine: WindowingEngine = ...,
214253
engine_kwargs: WindowingEngineKwargs = ...,
215-
**kwargs,
254+
**kwargs: Any,
255+
) -> DataFrame: ...
256+
@overload
257+
def transform(
258+
self,
259+
func: Callable,
260+
*args: Any,
261+
**kwargs: Any,
262+
) -> DataFrame: ...
263+
@overload
264+
def transform(
265+
self, func: TransformReductionListType, *args, **kwargs
216266
) -> DataFrame: ...
217267
def filter(
218268
self, func: Callable, dropna: bool = ..., *args, **kwargs

pandas-stubs/core/indexes/multi.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class MultiIndex(Index):
5454
@classmethod
5555
def from_product(
5656
cls,
57-
iterables: Sequence[SequenceNotStr[Hashable]],
57+
iterables: Sequence[SequenceNotStr[Hashable] | pd.Series | pd.Index],
5858
sortorder: int | None = ...,
5959
names: SequenceNotStr[Hashable] = ...,
6060
) -> Self: ...

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ty = "^0.0.1a8"
3737

3838
[tool.poetry.group.dev.dependencies]
3939
mypy = "1.16.0"
40-
pandas = "2.2.3"
40+
pandas = "2.3.0"
4141
pyarrow = ">=10.0.1"
4242
pytest = ">=7.1.2"
4343
pyright = ">=1.1.400"

tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
TYPE_CHECKING_INVALID_USAGE: Final = TYPE_CHECKING
5252
WINDOWS = os.name == "nt" or "cygwin" in platform.system().lower()
53-
PD_LTE_22 = Version(pd.__version__) < Version("2.2.999")
53+
PD_LTE_23 = Version(pd.__version__) < Version("2.3.999")
5454
NUMPY20 = np.lib.NumpyVersion(np.__version__) >= "2.0.0"
5555

5656

tests/test_errors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from tests import (
7-
PD_LTE_22,
7+
PD_LTE_23,
88
WINDOWS,
99
)
1010

@@ -108,13 +108,13 @@ def test_specification_error() -> None:
108108

109109

110110
def test_setting_with_copy_error() -> None:
111-
if PD_LTE_22:
111+
if PD_LTE_23:
112112
with pytest.raises(errors.SettingWithCopyError):
113113
raise errors.SettingWithCopyError()
114114

115115

116116
def test_setting_with_copy_warning() -> None:
117-
if PD_LTE_22:
117+
if PD_LTE_23:
118118
with pytest.warns(errors.SettingWithCopyWarning):
119119
warnings.warn("", errors.SettingWithCopyWarning)
120120

0 commit comments

Comments
 (0)