Skip to content

Commit 0f2b873

Browse files
authored
Fix (#73270)
1 parent 04de74f commit 0f2b873

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

test/legacy_test/test_index_select_op.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,55 @@ def init_dtype_type(self):
9292
self.index_size = 10
9393

9494

95+
class TestIndexSelectOp_ZeroSize(OpTest):
96+
def setUp(self):
97+
self.python_api = paddle.index_select
98+
self.public_python_api = paddle.index_select
99+
self.op_type = "index_select"
100+
self.init_dtype_type()
101+
102+
index_np = np.random.randint(
103+
low=-self.x_shape[self.dim],
104+
high=self.x_shape[self.dim],
105+
size=self.index_size,
106+
)
107+
x_np = np.random.random(self.x_shape).astype(self.x_type)
108+
if self.dtype == np.complex64 or self.dtype == np.complex128:
109+
x_np = (
110+
np.random.random(self.x_shape)
111+
+ 1j * np.random.random(self.x_shape)
112+
).astype(self.x_type)
113+
self.inputs = {'X': x_np, 'Index': index_np}
114+
self.attrs = {'dim': self.dim}
115+
outer_loop = np.prod(self.x_shape[: self.dim])
116+
x_reshape = [outer_loop, *self.x_shape[self.dim :]]
117+
x_np_reshape = np.reshape(x_np, tuple(x_reshape))
118+
out_list = []
119+
for i in range(outer_loop):
120+
for j in range(self.index_size):
121+
out_list.append(x_np_reshape[i, index_np[j]])
122+
self.out_shape = list(self.x_shape)
123+
self.out_shape[self.dim] = self.index_size
124+
self.out_shape = tuple(self.out_shape)
125+
126+
out = np.reshape(out_list, self.out_shape)
127+
self.outputs = {'Out': out}
128+
129+
def test_check_output(self):
130+
self.check_output(check_pir=True)
131+
132+
def test_check_grad_normal(self):
133+
self.check_grad(['X'], 'Out', check_pir=True)
134+
135+
def init_dtype_type(self):
136+
self.x_type = np.float64
137+
self.index_type = np.int64
138+
self.dim = 1
139+
# shape[dim] can not be 0.
140+
self.x_shape = (0, 10, 0, 0)
141+
self.index_size = 10
142+
143+
95144
class TestIndexSelectOpCaseSingleThread(TestIndexSelectOp):
96145
def init_dtype_type(self):
97146
if base.is_compiled_with_cuda():

0 commit comments

Comments
 (0)