Skip to content

Commit 5191ff4

Browse files
committed
.
1 parent 27927b5 commit 5191ff4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def test_take_along_axis(x, data):
7373
# TODO
7474
# 2. negative indices
7575
# 3. different dtypes for indices
76+
# 4. "broadcast-compatible" indices
7677
axis = data.draw(
7778
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
7879
label="axis"
@@ -84,8 +85,8 @@ def test_take_along_axis(x, data):
8485
axis_kw = {"axis": axis}
8586
n_axis = axis + x.ndim if axis < 0 else axis
8687

87-
len_axis = data.draw(st.integers(0, 2*x.shape[n_axis]), label="len_axis")
88-
idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:]
88+
new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len")
89+
idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:]
8990
indices = data.draw(
9091
hh.arrays(
9192
shape=idx_shape,
@@ -102,7 +103,7 @@ def test_take_along_axis(x, data):
102103
ph.assert_shape(
103104
"take_along_axis",
104105
out_shape=out.shape,
105-
expected=x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:],
106+
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
106107
kw=dict(
107108
x=x,
108109
indices=indices,
@@ -117,5 +118,5 @@ def test_take_along_axis(x, data):
117118
a_1d = x[ii + (slice(None),) + kk]
118119
i_1d = indices[ii + (slice(None),) + kk]
119120
o_1d = out[ii + (slice(None),) + kk]
120-
for j in range(len_axis):
121+
for j in range(new_len):
121122
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'

0 commit comments

Comments
 (0)