Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SimpleOpBuilder : public BaseOpBuilder {

static constexpr std::array<std::string_view, 2> gridsample_supported_modes = {"bilinear", "nearest"};
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
static constexpr std::array<std::string_view, 3> scatternd_supported_reduction = {"none", "add", "mul"};
};

Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
Expand Down Expand Up @@ -101,6 +102,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
}
}

// QNN ScatterND doesn't support MAX, MIN reduction
if (op_type == "ScatterND") {
NodeAttrHelper node_helper(node_unit);
std::string reduction = node_helper.Get("reduction", "none");
ORT_RETURN_IF_NOT(utils::ArrayHasString(scatternd_supported_reduction, reduction), "ScatterND does not support reduction ",
reduction.c_str());
}

return Status::OK();
}

Expand Down Expand Up @@ -254,6 +263,31 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

// Process Reduction attribute of ScatterND op
Status ProcessScatterNDReductionAttribute(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>& param_tensor_names) {
NodeAttrHelper node_helper(node_unit);
std::string reduction = node_helper.Get("reduction", "none");
Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT;
reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
if ("none" == reduction) {
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_NONE;
} else if ("add" == reduction) {
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_ADD;
} else if ("mul" == reduction) {
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_MUL;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterND support only reduction:{none, add, mul}.");
}
QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ND_PARAM_REDUCTION,
reduction_qnn_scalar);
param_tensor_names.push_back(reduction_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(reduction_param));

return Status::OK();
}

Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
Expand Down Expand Up @@ -358,6 +392,11 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names));
}

if (op_type == "ScatterND") {
// Process reduction attribute
ORT_RETURN_IF_ERROR(ProcessScatterNDReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
}

return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
Expand Down
72 changes: 72 additions & 0 deletions onnxruntime/test/providers/qnn/simple_op_htp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,78 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) {
ExpectedEPNodeAssignment::All);
}

// Test ScatterND with reduction ADD on HTP
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_add) {
std::vector<int64_t> data = {0, 1, 2, 3};
std::vector<int64_t> indices = {1};
std::vector<int64_t> updates = {10};
RunOpTest<int64_t>("ScatterND",
{
TestInputDef<int64_t>({4}, false, std::move(data)),
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
TestInputDef<int64_t>({1}, false, std::move(updates)),
},
{
utils::MakeAttribute("reduction", "add"),
},
17,
ExpectedEPNodeAssignment::All);
}

// Test ScatterND with reduction Mul on HTP
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_mul) {
std::vector<int64_t> data = {0, 1, 2, 3};
std::vector<int64_t> indices = {1};
std::vector<int64_t> updates = {10};
RunOpTest<int64_t>("ScatterND",
{
TestInputDef<int64_t>({4}, false, std::move(data)),
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
TestInputDef<int64_t>({1}, false, std::move(updates)),
},
{
utils::MakeAttribute("reduction", "mul"),
},
17,
ExpectedEPNodeAssignment::All);
}

// Test ScatterND with reduction Max on CPU Fallback
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_max) {
std::vector<int64_t> data = {0, 1, 2, 3};
std::vector<int64_t> indices = {1};
std::vector<int64_t> updates = {10};
RunOpTest<int64_t>("ScatterND",
{
TestInputDef<int64_t>({4}, false, std::move(data)),
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
TestInputDef<int64_t>({1}, false, std::move(updates)),
},
{
utils::MakeAttribute("reduction", "max"),
},
17,
ExpectedEPNodeAssignment::None);
}

// Test ScatterND with reduction Min on CPU Fallback
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_min) {
std::vector<int64_t> data = {0, 1, 2, 3};
std::vector<int64_t> indices = {1};
std::vector<int64_t> updates = {10};
RunOpTest<int64_t>("ScatterND",
{
TestInputDef<int64_t>({4}, false, std::move(data)),
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
TestInputDef<int64_t>({1}, false, std::move(updates)),
},
{
utils::MakeAttribute("reduction", "min"),
},
17,
ExpectedEPNodeAssignment::None);
}

// Test 8-bit QDQ GridSample with bilinear
TEST_F(QnnHTPBackendTests, GridSample_Bilinear) {
RunQDQOpTest<uint8_t>("GridSample",
Expand Down
Loading