Skip to content

Commit 06eb4e3

Browse files
GH1248/GH1249 DataFrame.assign and MultiIndex.from_product Series/Index (#1250)
* GH1248/GH1249 DataFrame.assign and MultiIndex.from_product Series/Index * GH1248/GH1249 PR feedback
1 parent b12c28d commit 06eb4e3

File tree

4 files changed

+42
-4
lines changed

4 files changed

+42
-4
lines changed

pandas-stubs/_typing.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,12 @@ TimeZones: TypeAlias = str | tzinfo | None | int
10131013

10141014
# Evaluates to a DataFrame column in DataFrame.assign context.
10151015
IntoColumn: TypeAlias = (
1016-
AnyArrayLike | Scalar | Callable[[DataFrame], AnyArrayLike | Scalar] | None
1016+
AnyArrayLike
1017+
| Scalar
1018+
| Callable[[DataFrame], AnyArrayLike | Scalar | list[Scalar] | range]
1019+
| list[Scalar]
1020+
| range
1021+
| None
10171022
)
10181023

10191024
DatetimeLike: TypeAlias = datetime.datetime | np.datetime64 | Timestamp

pandas-stubs/core/indexes/multi.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class MultiIndex(Index[Any]):
5555
@classmethod
5656
def from_product(
5757
cls,
58-
iterables: Sequence[SequenceNotStr[Hashable]],
58+
iterables: Sequence[SequenceNotStr[Hashable] | pd.Series | pd.Index],
5959
sortorder: int | None = ...,
6060
names: SequenceNotStr[Hashable] = ...,
6161
) -> Self: ...

tests/test_frame.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
)
4646
import xarray as xr
4747

48-
from pandas._typing import Scalar
48+
from pandas._typing import (
49+
Scalar,
50+
)
4951

5052
from tests import (
5153
PD_LTE_23,
@@ -305,9 +307,26 @@ def test_types_head_tail() -> None:
305307

306308
def test_types_assign() -> None:
307309
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
308-
df.assign(col3=lambda frame: frame.sum(axis=1))
310+
311+
check(
312+
assert_type(df.assign(col3=lambda frame: frame.sum(axis=1)), pd.DataFrame),
313+
pd.DataFrame,
314+
)
309315
df["col3"] = df.sum(axis=1)
310316

317+
df = pd.DataFrame({"a": [1, 2, 3]})
318+
check(
319+
assert_type(
320+
df.assign(b=lambda df: range(len(df)), c=lambda _: [10, 20, 30]),
321+
pd.DataFrame,
322+
),
323+
pd.DataFrame,
324+
)
325+
check(
326+
assert_type(df.assign(b=range(len(df)), c=[10, 20, 30]), pd.DataFrame),
327+
pd.DataFrame,
328+
)
329+
311330

312331
def test_assign() -> None:
313332
df = pd.DataFrame({"a": [1, 2, 3], 1: [4, 5, 6]})

tests/test_indexes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ def test_index_astype() -> None:
6161
pd.DataFrame,
6262
)
6363

64+
df = pd.DataFrame({"a": [1, 2, 3]})
65+
check(
66+
assert_type(
67+
pd.MultiIndex.from_product([["x", "y"], df.columns]), pd.MultiIndex
68+
),
69+
pd.MultiIndex,
70+
)
71+
check(
72+
assert_type(
73+
pd.MultiIndex.from_product([["x", "y"], pd.Series([1, 2])]), pd.MultiIndex
74+
),
75+
pd.MultiIndex,
76+
)
77+
6478

6579
def test_multiindex_get_level_values() -> None:
6680
mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"])

0 commit comments

Comments
 (0)