Skip to content

Commit e1328fc

Browse files
authored
TST: enable 2D tests for MaskedArrays, fix+test shift (#61826)
1 parent 688e2a0 commit e1328fc

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

pandas/core/arrays/masked.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
masked_reductions,
6060
)
6161
from pandas.core.array_algos.quantile import quantile_with_mask
62+
from pandas.core.array_algos.transforms import shift
6263
from pandas.core.arraylike import OpsMixin
6364
from pandas.core.arrays._utils import to_numpy_dtype_inference
6465
from pandas.core.arrays.base import ExtensionArray
@@ -361,6 +362,17 @@ def ravel(self, *args, **kwargs) -> Self:
361362
mask = self._mask.ravel(*args, **kwargs)
362363
return type(self)(data, mask)
363364

365+
def shift(self, periods: int = 1, fill_value=None) -> Self:
366+
# NB: shift is always along axis=0
367+
axis = 0
368+
if fill_value is None:
369+
new_data = shift(self._data, periods, axis, 0)
370+
new_mask = shift(self._mask, periods, axis, True)
371+
else:
372+
new_data = shift(self._data, periods, axis, fill_value)
373+
new_mask = shift(self._mask, periods, axis, False)
374+
return type(self)(new_data, new_mask)
375+
364376
@property
365377
def T(self) -> Self:
366378
return self._simple_new(self._data.T, self._mask.T)

pandas/tests/extension/base/dim2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ def skip_if_doesnt_support_2d(self, dtype, request):
3232
# TODO: is there a less hacky way of checking this?
3333
pytest.skip(f"{dtype} does not support 2D.")
3434

35+
def test_shift_2d(self, data):
36+
arr2d = data.repeat(2).reshape(-1, 2)
37+
38+
for n in [1, -2]:
39+
for fill_value in [None, data[0]]:
40+
result = arr2d.shift(n, fill_value=fill_value)
41+
expected_col = data.shift(n, fill_value=fill_value)
42+
tm.assert_extension_array_equal(result[:, 0], expected_col)
43+
tm.assert_extension_array_equal(result[:, 1], expected_col)
44+
3545
def test_transpose(self, data):
3646
arr2d = data.repeat(2).reshape(-1, 2)
3747
shape = arr2d.shape

pandas/tests/extension/test_masked.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ def data_for_grouping(dtype):
168168

169169

170170
class TestMaskedArrays(base.ExtensionTests):
171+
@pytest.fixture(autouse=True)
172+
def skip_if_doesnt_support_2d(self, dtype, request):
173+
# Override the fixture so that we run these tests.
174+
assert not dtype._supports_2d
175+
# If dtype._supports_2d is ever changed to True, then this fixture
176+
# override becomes unnecessary.
177+
171178
@pytest.mark.parametrize("na_action", [None, "ignore"])
172179
def test_map(self, data_missing, na_action):
173180
result = data_missing.map(lambda x: x, na_action=na_action)
@@ -402,7 +409,3 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
402409

403410
else:
404411
raise NotImplementedError(f"{op_name} not supported")
405-
406-
407-
class Test2DCompat(base.Dim2CompatTests):
408-
pass

0 commit comments

Comments
 (0)