Skip to content

Commit b10aea5

Browse files
asmeurerev-br
authored andcommitted
Fix test_take to make axis optional when ndim == 1
I didn't explicitly test axis=None because it's not clear to me that should actually be supported, given that that's the same as axis=0.
1 parent 8c8cb69 commit b10aea5

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,22 @@ def test_take(x, data):
2020
# * negative axis
2121
# * negative indices
2222
# * different dtypes for indices
23-
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
23+
24+
# axis is optional but only if x.ndim == 1
25+
_axis_st = st.integers(0, max(x.ndim - 1, 0))
26+
if x.ndim == 1:
27+
kw = data.draw(hh.kwargs(axis=_axis_st))
28+
else:
29+
kw = {"axis": data.draw(_axis_st)}
30+
axis = kw.get("axis", 0)
2431
_indices = data.draw(
2532
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
2633
label="_indices",
2734
)
2835
indices = xp.asarray(_indices, dtype=dh.default_int)
2936
note(f"{indices=}")
3037

31-
out = xp.take(x, indices, axis=axis)
38+
out = xp.take(x, indices, **kw)
3239

3340
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
3441
ph.assert_shape(

0 commit comments

Comments
 (0)