@@ -73,6 +73,7 @@ def test_take_along_axis(x, data):
73
73
# TODO
74
74
# 2. negative indices
75
75
# 3. different dtypes for indices
76
+ # 4. "broadcast-compatible" indices
76
77
axis = data .draw (
77
78
st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
78
79
label = "axis"
@@ -84,8 +85,8 @@ def test_take_along_axis(x, data):
84
85
axis_kw = {"axis" : axis }
85
86
n_axis = axis + x .ndim if axis < 0 else axis
86
87
87
- len_axis = data .draw (st .integers (0 , 2 * x .shape [n_axis ]), label = "len_axis " )
88
- idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
88
+ new_len = data .draw (st .integers (0 , 2 * x .shape [n_axis ]), label = "new_len " )
89
+ idx_shape = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :]
89
90
indices = data .draw (
90
91
hh .arrays (
91
92
shape = idx_shape ,
@@ -102,7 +103,7 @@ def test_take_along_axis(x, data):
102
103
ph .assert_shape (
103
104
"take_along_axis" ,
104
105
out_shape = out .shape ,
105
- expected = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :],
106
+ expected = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :],
106
107
kw = dict (
107
108
x = x ,
108
109
indices = indices ,
@@ -117,5 +118,5 @@ def test_take_along_axis(x, data):
117
118
a_1d = x [ii + (slice (None ),) + kk ]
118
119
i_1d = indices [ii + (slice (None ),) + kk ]
119
120
o_1d = out [ii + (slice (None ),) + kk ]
120
- for j in range (len_axis ):
121
+ for j in range (new_len ):
121
122
assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
0 commit comments