diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index e7a6b207363c3..fefd70fef35c9 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -59,6 +59,7 @@ masked_reductions, ) from pandas.core.array_algos.quantile import quantile_with_mask +from pandas.core.array_algos.transforms import shift from pandas.core.arraylike import OpsMixin from pandas.core.arrays._utils import to_numpy_dtype_inference from pandas.core.arrays.base import ExtensionArray @@ -361,6 +362,17 @@ def ravel(self, *args, **kwargs) -> Self: mask = self._mask.ravel(*args, **kwargs) return type(self)(data, mask) + def shift(self, periods: int = 1, fill_value=None) -> Self: + # NB: shift is always along axis=0 + axis = 0 + if fill_value is None: + new_data = shift(self._data, periods, axis, 0) + new_mask = shift(self._mask, periods, axis, True) + else: + new_data = shift(self._data, periods, axis, fill_value) + new_mask = shift(self._mask, periods, axis, False) + return type(self)(new_data, new_mask) + @property def T(self) -> Self: return self._simple_new(self._data.T, self._mask.T) diff --git a/pandas/tests/extension/base/dim2.py b/pandas/tests/extension/base/dim2.py index 8c7d8ff491cd3..890766acbd610 100644 --- a/pandas/tests/extension/base/dim2.py +++ b/pandas/tests/extension/base/dim2.py @@ -32,6 +32,16 @@ def skip_if_doesnt_support_2d(self, dtype, request): # TODO: is there a less hacky way of checking this? pytest.skip(f"{dtype} does not support 2D.") + def test_shift_2d(self, data): + arr2d = data.repeat(2).reshape(-1, 2) + + for n in [1, -2]: + for fill_value in [None, data[0]]: + result = arr2d.shift(n, fill_value=fill_value) + expected_col = data.shift(n, fill_value=fill_value) + tm.assert_extension_array_equal(result[:, 0], expected_col) + tm.assert_extension_array_equal(result[:, 1], expected_col) + def test_transpose(self, data): arr2d = data.repeat(2).reshape(-1, 2) shape = arr2d.shape diff --git a/pandas/tests/extension/test_masked.py b/pandas/tests/extension/test_masked.py index 3b9079d06e231..c7fe9e99ec6e5 100644 --- a/pandas/tests/extension/test_masked.py +++ b/pandas/tests/extension/test_masked.py @@ -168,6 +168,13 @@ def data_for_grouping(dtype): class TestMaskedArrays(base.ExtensionTests): + @pytest.fixture(autouse=True) + def skip_if_doesnt_support_2d(self, dtype, request): + # Override the fixture so that we run these tests. + assert not dtype._supports_2d + # If dtype._supports_2d is ever changed to True, then this fixture + # override becomes unnecessary. + @pytest.mark.parametrize("na_action", [None, "ignore"]) def test_map(self, data_missing, na_action): result = data_missing.map(lambda x: x, na_action=na_action) @@ -402,7 +409,3 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool): else: raise NotImplementedError(f"{op_name} not supported") - - -class Test2DCompat(base.Dim2CompatTests): - pass