Skip to content

Commit 625289c

Browse files
[QNN EP] Add ScatterND reduction attribute (#24844)
### Description - Add support for ScatterND reduction attribute - Gracefully handle the unsupported reduction values - Add unit tests to validate Reduction attribute support ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 3a20910 commit 625289c

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class SimpleOpBuilder : public BaseOpBuilder {
4040

4141
static constexpr std::array<std::string_view, 2> gridsample_supported_modes = {"bilinear", "nearest"};
4242
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
43+
static constexpr std::array<std::string_view, 3> scatternd_supported_reduction = {"none", "add", "mul"};
4344
};
4445

4546
Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
@@ -101,6 +102,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
101102
}
102103
}
103104

105+
// QNN ScatterND doesn't support MAX, MIN reduction
106+
if (op_type == "ScatterND") {
107+
NodeAttrHelper node_helper(node_unit);
108+
std::string reduction = node_helper.Get("reduction", "none");
109+
ORT_RETURN_IF_NOT(utils::ArrayHasString(scatternd_supported_reduction, reduction), "ScatterND does not support reduction ",
110+
reduction.c_str());
111+
}
112+
104113
return Status::OK();
105114
}
106115

@@ -254,6 +263,31 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper,
254263
return Status::OK();
255264
}
256265

266+
// Process Reduction attribute of ScatterND op
267+
Status ProcessScatterNDReductionAttribute(QnnModelWrapper& qnn_model_wrapper,
268+
const NodeUnit& node_unit,
269+
std::vector<std::string>& param_tensor_names) {
270+
NodeAttrHelper node_helper(node_unit);
271+
std::string reduction = node_helper.Get("reduction", "none");
272+
Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT;
273+
reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
274+
if ("none" == reduction) {
275+
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_NONE;
276+
} else if ("add" == reduction) {
277+
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_ADD;
278+
} else if ("mul" == reduction) {
279+
reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_MUL;
280+
} else {
281+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterND support only reduction:{none, add, mul}.");
282+
}
283+
QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ND_PARAM_REDUCTION,
284+
reduction_qnn_scalar);
285+
param_tensor_names.push_back(reduction_param.GetParamTensorName());
286+
qnn_model_wrapper.AddParamWrapper(std::move(reduction_param));
287+
288+
return Status::OK();
289+
}
290+
257291
Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
258292
const NodeUnit& node_unit,
259293
std::vector<std::string>&& input_names,
@@ -358,6 +392,11 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
358392
ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names));
359393
}
360394

395+
if (op_type == "ScatterND") {
396+
// Process reduction attribute
397+
ORT_RETURN_IF_ERROR(ProcessScatterNDReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
398+
}
399+
361400
return ProcessOutputs(qnn_model_wrapper, node_unit,
362401
std::move(input_names),
363402
std::move(param_tensor_names),

onnxruntime/test/providers/qnn/simple_op_htp_test.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,78 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) {
10171017
ExpectedEPNodeAssignment::All);
10181018
}
10191019

1020+
// Test ScatterND with reduction ADD on HTP
1021+
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_add) {
1022+
std::vector<int64_t> data = {0, 1, 2, 3};
1023+
std::vector<int64_t> indices = {1};
1024+
std::vector<int64_t> updates = {10};
1025+
RunOpTest<int64_t>("ScatterND",
1026+
{
1027+
TestInputDef<int64_t>({4}, false, std::move(data)),
1028+
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
1029+
TestInputDef<int64_t>({1}, false, std::move(updates)),
1030+
},
1031+
{
1032+
utils::MakeAttribute("reduction", "add"),
1033+
},
1034+
17,
1035+
ExpectedEPNodeAssignment::All);
1036+
}
1037+
1038+
// Test ScatterND with reduction Mul on HTP
1039+
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_mul) {
1040+
std::vector<int64_t> data = {0, 1, 2, 3};
1041+
std::vector<int64_t> indices = {1};
1042+
std::vector<int64_t> updates = {10};
1043+
RunOpTest<int64_t>("ScatterND",
1044+
{
1045+
TestInputDef<int64_t>({4}, false, std::move(data)),
1046+
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
1047+
TestInputDef<int64_t>({1}, false, std::move(updates)),
1048+
},
1049+
{
1050+
utils::MakeAttribute("reduction", "mul"),
1051+
},
1052+
17,
1053+
ExpectedEPNodeAssignment::All);
1054+
}
1055+
1056+
// Test ScatterND with reduction Max on CPU Fallback
1057+
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_max) {
1058+
std::vector<int64_t> data = {0, 1, 2, 3};
1059+
std::vector<int64_t> indices = {1};
1060+
std::vector<int64_t> updates = {10};
1061+
RunOpTest<int64_t>("ScatterND",
1062+
{
1063+
TestInputDef<int64_t>({4}, false, std::move(data)),
1064+
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
1065+
TestInputDef<int64_t>({1}, false, std::move(updates)),
1066+
},
1067+
{
1068+
utils::MakeAttribute("reduction", "max"),
1069+
},
1070+
17,
1071+
ExpectedEPNodeAssignment::None);
1072+
}
1073+
1074+
// Test ScatterND with reduction Min on CPU Fallback
1075+
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_min) {
1076+
std::vector<int64_t> data = {0, 1, 2, 3};
1077+
std::vector<int64_t> indices = {1};
1078+
std::vector<int64_t> updates = {10};
1079+
RunOpTest<int64_t>("ScatterND",
1080+
{
1081+
TestInputDef<int64_t>({4}, false, std::move(data)),
1082+
TestInputDef<int64_t>({1, 1}, false, std::move(indices)),
1083+
TestInputDef<int64_t>({1}, false, std::move(updates)),
1084+
},
1085+
{
1086+
utils::MakeAttribute("reduction", "min"),
1087+
},
1088+
17,
1089+
ExpectedEPNodeAssignment::None);
1090+
}
1091+
10201092
// Test 8-bit QDQ GridSample with bilinear
10211093
TEST_F(QnnHTPBackendTests, GridSample_Bilinear) {
10221094
RunQDQOpTest<uint8_t>("GridSample",

0 commit comments

Comments
 (0)