Skip to content

Commit fa1215e

Browse files
Revert "[BIT] nonzero (#72244) (#73010)" (#73265)
This reverts commit 2d085d2.
1 parent 0a5008b commit fa1215e

File tree

6 files changed

+8
-74
lines changed

6 files changed

+8
-74
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,27 +2366,13 @@ bool NonzeroOpInferSymbolicShape(
23662366
common::errors::InvalidArgument(
23672367
"Input(x) should have number of dimension at least 1."));
23682368

2369-
bool zero = 0;
2370-
for (int i = 0; i < rank; i++) {
2371-
if (x_shape[i] == 0) {
2372-
zero = 1;
2373-
break;
2374-
}
2375-
}
2376-
if (zero) {
2377-
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{0},
2378-
symbol::DimExpr{rank}};
2379-
symbol::ShapeOrDataDimExprs shape_data{
2380-
symbol::TensorShapeOrDataDimExprs(out_shape)};
2381-
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
2382-
} else {
2383-
std::string sym_name = infer_context->GetNextSymName();
2384-
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
2385-
symbol::DimExpr{rank}};
2386-
symbol::ShapeOrDataDimExprs shape_data{
2387-
symbol::TensorShapeOrDataDimExprs(out_shape)};
2388-
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
2389-
}
2369+
std::string sym_name = infer_context->GetNextSymName();
2370+
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
2371+
symbol::DimExpr{rank}};
2372+
2373+
symbol::ShapeOrDataDimExprs shape_data{
2374+
symbol::TensorShapeOrDataDimExprs(out_shape)};
2375+
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
23902376
return true;
23912377
}
23922378

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,12 +2873,7 @@ void NonZeroInferMeta(const MetaTensor& condition, MetaTensor* out) {
28732873
1UL,
28742874
common::errors::InvalidArgument(
28752875
"Input(Condition) should have number of dimension at least 1"));
2876-
if (condition.numel() == 0) {
2877-
out->set_dims(common::make_ddim({0, rank}));
2878-
} else {
2879-
out->set_dims(common::make_ddim({-1, rank}));
2880-
}
2881-
2876+
out->set_dims(common::make_ddim({-1, rank}));
28822877
out->set_dtype(DataType::INT64);
28832878
}
28842879

paddle/phi/kernels/cpu/nonzero_kernel.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ void NonZeroKernel(const Context& dev_ctx,
5555
auto dims = condition.dims();
5656
const int rank = dims.size();
5757

58-
if (numel == 0) {
59-
dev_ctx.template Alloc<T>(out);
60-
return;
61-
}
62-
6358
std::vector<int64_t> true_index;
6459
for (auto i = 0; i < numel; i++) {
6560
if (static_cast<bool>(cond_data[i])) {

paddle/phi/kernels/gpu/nonzero_kernel.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ template <typename T, typename Context>
6565
void NonZeroKernel(const Context &dev_ctx,
6666
const DenseTensor &condition,
6767
DenseTensor *out) {
68-
if (condition.numel() == 0) {
69-
dev_ctx.template Alloc<T>(out);
70-
return;
71-
}
7268
DenseTensor in_data;
7369
auto dims = condition.dims();
7470
using Functor = IndexFunctor<T, int64_t, int64_t>;

paddle/phi/kernels/xpu/nonzero_kernel.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ void NonZeroKernel(const Context& dev_ctx,
3232

3333
using XPUType = typename XPUTypeTrait<T>::Type;
3434

35-
if (numel == 0) {
36-
dev_ctx.template Alloc<T>(out);
37-
return;
38-
}
39-
4035
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
4136
int64_t* true_num = RAII_GUARD.alloc_l3_or_gm<int64_t>(1);
4237
int64_t* workspace =

test/legacy_test/test_nonzero_api.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,6 @@ def test_nonzero_api_as_tuple(self):
6363
expect_out = np.array([0, 1])
6464
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
6565

66-
data = np.zeros([10, 3, 0], dtype="float32")
67-
with program_guard(Program(), Program()):
68-
x = paddle.static.data(name='x', shape=[10, 3, 0], dtype='float32')
69-
if not paddle.framework.use_pir_api():
70-
x.desc.set_need_check_feed(False)
71-
y = paddle.nonzero(x, as_tuple=True)
72-
self.assertEqual(type(y), tuple)
73-
self.assertEqual(len(y), 3)
74-
expect_out = np.zeros([0])
75-
for item in y:
76-
np.testing.assert_array_equal(expect_out, item)
77-
7866
def test_nonzero_api(self):
7967
paddle.enable_static()
8068
data = np.array([[1, 0], [0, 1]], dtype="float32")
@@ -193,26 +181,5 @@ def return_outputs(self):
193181
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))}
194182

195183

196-
class TestZeroSizeOp(TestNonzeroOp):
197-
198-
def init_shape(self):
199-
self.shape = [0, 10]
200-
201-
def init_dtype(self):
202-
self.dtype = np.float64
203-
204-
205-
class TestZeroSizeOpCase2(TestNonzeroOp):
206-
207-
def init_shape(self):
208-
self.shape = [0, 10]
209-
210-
def init_dtype(self):
211-
self.dtype = np.float64
212-
213-
def test_check_output(self):
214-
self.check_output(check_pir=True, check_symbol_infer=True)
215-
216-
217184
if __name__ == "__main__":
218185
unittest.main()

0 commit comments

Comments (0)