Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2366,13 +2366,27 @@ bool NonzeroOpInferSymbolicShape(
common::errors::InvalidArgument(
"Input(x) should have number of dimension at least 1."));

std::string sym_name = infer_context->GetNextSymName();
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
symbol::DimExpr{rank}};

symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(out_shape)};
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
bool zero = 0;
for (int i = 0; i < rank; i++) {
if (x_shape[i] == 0) {
zero = 1;
break;
}
}
if (zero) {
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{0},
symbol::DimExpr{rank}};
symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(out_shape)};
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
} else {
std::string sym_name = infer_context->GetNextSymName();
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
symbol::DimExpr{rank}};
symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(out_shape)};
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
}
return true;
}

Expand Down
7 changes: 6 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2873,7 +2873,12 @@ void NonZeroInferMeta(const MetaTensor& condition, MetaTensor* out) {
1UL,
common::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
out->set_dims(common::make_ddim({-1, rank}));
if (condition.numel() == 0) {
out->set_dims(common::make_ddim({0, rank}));
} else {
out->set_dims(common::make_ddim({-1, rank}));
}

out->set_dtype(DataType::INT64);
}

Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/nonzero_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ void NonZeroKernel(const Context& dev_ctx,
auto dims = condition.dims();
const int rank = dims.size();

if (numel == 0) {
dev_ctx.template Alloc<T>(out);
return;
}

std::vector<int64_t> true_index;
for (auto i = 0; i < numel; i++) {
if (static_cast<bool>(cond_data[i])) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/nonzero_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ template <typename T, typename Context>
void NonZeroKernel(const Context &dev_ctx,
const DenseTensor &condition,
DenseTensor *out) {
if (condition.numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
DenseTensor in_data;
auto dims = condition.dims();
using Functor = IndexFunctor<T, int64_t, int64_t>;
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/xpu/nonzero_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ void NonZeroKernel(const Context& dev_ctx,

using XPUType = typename XPUTypeTrait<T>::Type;

if (numel == 0) {
dev_ctx.template Alloc<T>(out);
return;
}

xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int64_t* true_num = RAII_GUARD.alloc_l3_or_gm<int64_t>(1);
int64_t* workspace =
Expand Down
33 changes: 33 additions & 0 deletions test/legacy_test/test_nonzero_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ def test_nonzero_api_as_tuple(self):
expect_out = np.array([0, 1])
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

data = np.zeros([10, 3, 0], dtype="float32")
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[10, 3, 0], dtype='float32')
if not paddle.framework.use_pir_api():
x.desc.set_need_check_feed(False)
y = paddle.nonzero(x, as_tuple=True)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 3)
expect_out = np.zeros([0])
for item in y:
np.testing.assert_array_equal(expect_out, item)

def test_nonzero_api(self):
paddle.enable_static()
data = np.array([[1, 0], [0, 1]], dtype="float32")
Expand Down Expand Up @@ -181,5 +193,26 @@ def return_outputs(self):
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))}


class TestZeroSizeOp(TestNonzeroOp):

def init_shape(self):
self.shape = [0, 10]

def init_dtype(self):
self.dtype = np.float64


class TestZeroSizeOpCase2(TestNonzeroOp):

def init_shape(self):
self.shape = [0, 10]

def init_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_pir=True, check_symbol_infer=True)


if __name__ == "__main__":
unittest.main()
Loading