diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index b8dc94448..297e94057 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -260,7 +260,7 @@ class Index(IndexOpsMixin[S1]): **kwargs, ) -> Self: ... @property - def str(self) -> StringMethods[Self, MultiIndex]: ... + def str(self) -> StringMethods[Self, MultiIndex, np_ndarray_bool]: ... def is_(self, other) -> bool: ... def __len__(self) -> int: ... def __array__(self, dtype=...) -> np.ndarray: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 209237746..e12fa2cc7 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1163,7 +1163,7 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Series[S1]: ... def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ... @property - def str(self) -> StringMethods[Series, DataFrame]: ... + def str(self) -> StringMethods[Series, DataFrame, Series[bool]]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... @property diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index a21074dad..a3596aa5c 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -23,12 +23,15 @@ from pandas.core.base import NoNewAttributesMixin from pandas._typing import ( JoinHow, T, + np_ndarray_bool, ) # The _TS type is what is used for the result of str.split with expand=True _TS = TypeVar("_TS", DataFrame, MultiIndex) +# The _TM type is what is used for the result of str.match +_TM = TypeVar("_TM", Series[bool], np_ndarray_bool) -class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): +class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]): def __init__(self, data: T) -> None: ... def __getitem__(self, key: slice | int) -> T: ... def __iter__(self) -> T: ... @@ -100,7 +103,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): ) -> Series[bool]: ... def match( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> T: ... + ) -> _TM: ... def replace( self, pat: str, diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 85edc3b57..6777736d9 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -113,6 +113,13 @@ def test_str_split() -> None: check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex) +def test_str_match() -> None: + i = pd.Index( + ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + ) + check(assert_type(i.str.match("pp"), npt.NDArray[np.bool_]), np.ndarray, np.bool_) + + def test_index_rename() -> None: ind = pd.Index([1, 2, 3], name="foo") ind2 = ind.rename("goo") diff --git a/tests/test_series.py b/tests/test_series.py index e68028cc4..bcac374ed 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1481,7 +1481,7 @@ def test_string_accessors(): check(assert_type(s.str.ljust(80), pd.Series), pd.Series) check(assert_type(s.str.lower(), pd.Series), pd.Series) check(assert_type(s.str.lstrip("a"), pd.Series), pd.Series) - check(assert_type(s.str.match("pp"), pd.Series), pd.Series) + check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.normalize("NFD"), pd.Series), pd.Series) check(assert_type(s.str.pad(80, "right"), pd.Series), pd.Series) check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame)