Skip to content

Commit 5802c53

Browse files
HonrynaomiOvad
authored and committed
[WebNN] Fix some issues in reduction ops (microsoft#26289)
- Allow empty axes input - When axes is empty and 'noop_with_empty_axes' is true, WebNN should set axes to [] - Simplify the code
1 parent 1101682 commit 5802c53

File tree

3 files changed

+105
-109
lines changed

3 files changed

+105
-109
lines changed

js/web/test/suite-test-list.jsonc

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,66 +2147,66 @@
21472147
"test_reduce_log_sum_default",
21482148
"test_reduce_log_sum_desc_axes",
21492149
// tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64.
2150-
// "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example",
2151-
// "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random",
2152-
// "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_example",
2153-
// "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_random",
2154-
// "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_example",
2155-
// "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random",
2156-
// "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example",
2157-
// "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random",
2150+
"opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example",
2151+
"opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random",
2152+
"opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_example",
2153+
"opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_random",
2154+
"opset{7,8,9}/test_reduce_log_sum_exp_keepdims_example",
2155+
"opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random",
2156+
"opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example",
2157+
"opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random",
21582158
"test_reduce_log_sum_negative_axes",
21592159
"test_reduce_log_sum",
21602160
"test_reduce_max_default_axes_keepdim_example",
2161-
// "test_reduce_max_default_axes_keepdims_random",
2162-
// "test_reduce_max_do_not_keepdims_example",
2163-
// "test_reduce_max_do_not_keepdims_random",
2164-
// "test_reduce_max_keepdims_example",
2165-
// "test_reduce_max_keepdims_random",
2166-
// "test_reduce_max_negative_axes_keepdims_example",
2167-
// "test_reduce_max_negative_axes_keepdims_random",
2168-
// "test_reduce_mean_default_axes_keepdims_example",
2169-
// "test_reduce_mean_default_axes_keepdims_random",
2170-
// "test_reduce_mean_do_not_keepdims_example",
2171-
// "test_reduce_mean_do_not_keepdims_random",
2172-
// "test_reduce_mean_keepdims_example",
2173-
// "test_reduce_mean_keepdims_random",
2174-
// "test_reduce_mean_negative_axes_keepdims_example",
2175-
// "test_reduce_mean_negative_axes_keepdims_random",
2176-
// "test_reduce_min_default_axes_keepdims_example",
2177-
// "test_reduce_min_default_axes_keepdims_random",
2178-
// "test_reduce_min_do_not_keepdims_example",
2179-
// "test_reduce_min_do_not_keepdims_random",
2180-
// "test_reduce_min_keepdims_example",
2181-
// "test_reduce_min_keepdims_random",
2182-
// "test_reduce_min_negative_axes_keepdims_example",
2183-
// "test_reduce_min_negative_axes_keepdims_random",
2184-
// "test_reduce_prod_default_axes_keepdims_example",
2185-
// "test_reduce_prod_default_axes_keepdims_random",
2186-
// "test_reduce_prod_do_not_keepdims_example",
2187-
// "test_reduce_prod_do_not_keepdims_random",
2188-
// "test_reduce_prod_keepdims_example",
2189-
// "test_reduce_prod_keepdims_random",
2190-
// "test_reduce_prod_negative_axes_keepdims_example",
2191-
// "test_reduce_prod_negative_axes_keepdims_random",
2192-
// "test_reduce_sum_default_axes_keepdims_example",
2193-
// "test_reduce_sum_default_axes_keepdims_random",
2194-
// "test_reduce_sum_do_not_keepdims_example",
2195-
// "test_reduce_sum_do_not_keepdims_random",
2161+
"test_reduce_max_default_axes_keepdims_random",
2162+
"test_reduce_max_do_not_keepdims_example",
2163+
"test_reduce_max_do_not_keepdims_random",
2164+
"test_reduce_max_keepdims_example",
2165+
"test_reduce_max_keepdims_random",
2166+
"test_reduce_max_negative_axes_keepdims_example",
2167+
"test_reduce_max_negative_axes_keepdims_random",
2168+
"test_reduce_mean_default_axes_keepdims_example",
2169+
"test_reduce_mean_default_axes_keepdims_random",
2170+
"test_reduce_mean_do_not_keepdims_example",
2171+
"test_reduce_mean_do_not_keepdims_random",
2172+
"test_reduce_mean_keepdims_example",
2173+
"test_reduce_mean_keepdims_random",
2174+
"test_reduce_mean_negative_axes_keepdims_example",
2175+
"test_reduce_mean_negative_axes_keepdims_random",
2176+
"test_reduce_min_default_axes_keepdims_example",
2177+
"test_reduce_min_default_axes_keepdims_random",
2178+
"test_reduce_min_do_not_keepdims_example",
2179+
"test_reduce_min_do_not_keepdims_random",
2180+
"test_reduce_min_keepdims_example",
2181+
"test_reduce_min_keepdims_random",
2182+
"test_reduce_min_negative_axes_keepdims_example",
2183+
"test_reduce_min_negative_axes_keepdims_random",
2184+
"test_reduce_prod_default_axes_keepdims_example",
2185+
"test_reduce_prod_default_axes_keepdims_random",
2186+
"test_reduce_prod_do_not_keepdims_example",
2187+
"test_reduce_prod_do_not_keepdims_random",
2188+
"test_reduce_prod_keepdims_example",
2189+
"test_reduce_prod_keepdims_random",
2190+
"test_reduce_prod_negative_axes_keepdims_example",
2191+
"test_reduce_prod_negative_axes_keepdims_random",
2192+
"test_reduce_sum_default_axes_keepdims_example",
2193+
"test_reduce_sum_default_axes_keepdims_random",
2194+
"test_reduce_sum_do_not_keepdims_example",
2195+
"test_reduce_sum_do_not_keepdims_random",
21962196
"test_reduce_sum_empty_axes_input_noop_example",
21972197
"test_reduce_sum_empty_axes_input_noop_random",
2198-
// "test_reduce_sum_keepdims_example",
2199-
// "test_reduce_sum_keepdims_random",
2200-
// "test_reduce_sum_negative_axes_keepdims_example",
2201-
// "test_reduce_sum_negative_axes_keepdims_random",
2202-
// "test_reduce_sum_square_default_axes_keepdims_example",
2203-
// "test_reduce_sum_square_default_axes_keepdims_random",
2204-
// "test_reduce_sum_square_do_not_keepdims_example",
2205-
// "test_reduce_sum_square_do_not_keepdims_random",
2206-
// "test_reduce_sum_square_keepdims_example",
2207-
// "test_reduce_sum_square_keepdims_random",
2208-
// "test_reduce_sum_square_negative_axes_keepdims_example",
2209-
// "test_reduce_sum_square_negative_axes_keepdims_random",
2198+
"test_reduce_sum_keepdims_example",
2199+
"test_reduce_sum_keepdims_random",
2200+
"test_reduce_sum_negative_axes_keepdims_example",
2201+
"test_reduce_sum_negative_axes_keepdims_random",
2202+
"test_reduce_sum_square_default_axes_keepdims_example",
2203+
"test_reduce_sum_square_default_axes_keepdims_random",
2204+
"test_reduce_sum_square_do_not_keepdims_example",
2205+
"test_reduce_sum_square_do_not_keepdims_random",
2206+
"test_reduce_sum_square_keepdims_example",
2207+
"test_reduce_sum_square_keepdims_random",
2208+
"test_reduce_sum_square_negative_axes_keepdims_example",
2209+
"test_reduce_sum_square_negative_axes_keepdims_random",
22102210
// "test_reflect_pad",
22112211
"test_relu",
22122212
"test_reshape_allowzero_reordered",

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type);
3838
// Collects all the initializer tensors in the subGraph and its ancestor graphs.
3939
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);
4040

41-
inline std::vector<int64_t> HandleNegativeAxes(const std::vector<int64_t>& axes, size_t input_size) {
41+
inline std::vector<int64_t> HandleNegativeAxes(const gsl::span<const int64_t> axes, size_t input_size) {
4242
std::vector<int64_t> new_axes(axes.size());
4343
for (size_t i = 0; i < axes.size(); ++i) {
4444
new_axes[i] = HandleNegativeAxis(axes[i], input_size);

onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc

Lines changed: 49 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ namespace webnn {
1919
class ReductionOpBuilder : public BaseOpBuilder {
2020
// Add operator related.
2121
public:
22+
// Allow axes potentially being empty inputs that are ignored during processing.
23+
ReductionOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {}
2224
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
2325

2426
// Add operator related.
@@ -37,6 +39,7 @@ void ReductionOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons
3739
const auto& input_defs = node.InputDefs();
3840
if (input_defs.size() > 1) {
3941
model_builder.AddInitializerToSkip(input_defs[1]->Name()); // axes
42+
model_builder.AddInputToSkip(input_defs[1]->Name()); // axes
4043
}
4144
}
4245

@@ -53,71 +56,50 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
5356

5457
NodeAttrHelper helper(node);
5558
const auto keep_dims = helper.Get("keepdims", 1);
59+
5660
emscripten::val options = emscripten::val::object();
5761
options.set("label", node.Name());
5862
options.set("keepDimensions", keep_dims == 1);
59-
std::vector<int32_t> axes_data;
60-
61-
emscripten::val output = emscripten::val::object();
6263

64+
std::vector<int64_t> axes_data;
6365
const auto opset = node.SinceVersion();
6466
const auto& op_type = node.OpType();
6567
if (opset >= 18 || (op_type == "ReduceSum" && opset >= 13)) {
6668
// 'axes' is an optional input.
67-
const auto noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0);
68-
if (!GetTensorName(input_defs, 1).empty()) {
69-
// Optional input axes is provided, use axes initializer data.
70-
const auto& initializers(model_builder.GetInitializerTensors());
71-
const auto& axes_tensor = *initializers.at(input_defs[1]->Name());
72-
Initializer axes_initializer(axes_tensor);
73-
const auto axes_data_span = axes_initializer.DataAsSpan<int64_t>();
74-
std::transform(
75-
axes_data_span.begin(), axes_data_span.end(), std::back_inserter(axes_data),
76-
[input_rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, input_rank)); });
77-
} else {
78-
if (noop_with_empty_axes) {
79-
// When axes is empty and this attribute is set to true, input tensor will not be reduced.
80-
output = input;
81-
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
82-
return Status::OK();
69+
std::vector<int64_t> axes_shape;
70+
if (TensorExists(input_defs, 1)) {
71+
ORT_RETURN_IF_NOT(GetShape(*input_defs[1], axes_shape, logger), "Cannot get shape of input axes");
72+
if (axes_shape[0] != 0) {
73+
// Optional input axes is provided and we already ensure it is an initializer.
74+
// Use that initializer data.
75+
const auto& initializers(model_builder.GetInitializerTensors());
76+
const auto& axes_tensor = *initializers.at(input_defs[1]->Name());
77+
Initializer axes_initializer(axes_tensor);
78+
const auto axes_data_span = axes_initializer.DataAsSpan<int64_t>();
79+
axes_data = HandleNegativeAxes(axes_data_span, input_rank);
8380
}
8481
}
8582
} else {
8683
if (helper.HasAttr("axes")) {
87-
auto axes = helper.Get("axes", std::vector<int64_t>{});
88-
std::transform(
89-
axes.begin(), axes.end(), std::back_inserter(axes_data),
90-
[input_rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, input_rank)); });
84+
axes_data = GetResolvedAxes(helper, input_rank);
9185
}
9286
}
93-
if (axes_data.size() > 0) {
94-
options.set("axes", emscripten::val::array(axes_data));
95-
}
9687

97-
if (op_type == "ReduceL1") {
98-
output = model_builder.GetBuilder().call<emscripten::val>("reduceL1", input, options);
99-
} else if (op_type == "ReduceL2") {
100-
output = model_builder.GetBuilder().call<emscripten::val>("reduceL2", input, options);
101-
} else if (op_type == "ReduceLogSum") {
102-
output = model_builder.GetBuilder().call<emscripten::val>("reduceLogSum", input, options);
103-
} else if (op_type == "ReduceLogSumExp") {
104-
output = model_builder.GetBuilder().call<emscripten::val>("reduceLogSumExp", input, options);
105-
} else if (op_type == "ReduceMax") {
106-
output = model_builder.GetBuilder().call<emscripten::val>("reduceMax", input, options);
107-
} else if (op_type == "ReduceMean") {
108-
output = model_builder.GetBuilder().call<emscripten::val>("reduceMean", input, options);
109-
} else if (op_type == "ReduceMin") {
110-
output = model_builder.GetBuilder().call<emscripten::val>("reduceMin", input, options);
111-
} else if (op_type == "ReduceProd") {
112-
output = model_builder.GetBuilder().call<emscripten::val>("reduceProduct", input, options);
113-
} else if (op_type == "ReduceSum") {
114-
output = model_builder.GetBuilder().call<emscripten::val>("reduceSum", input, options);
115-
} else if (op_type == "ReduceSumSquare") {
116-
output = model_builder.GetBuilder().call<emscripten::val>("reduceSumSquare", input, options);
117-
} else {
118-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ReductionOpBuilder, unknown op: ", op_type);
88+
// When axes is not provided or is empty, check the 'noop_with_empty_axes' attribute:
89+
// - If it is false, perform reduction over all dimensions.
90+
// (In WebNN, this means the 'axes' option is not set.)
91+
// - If it is true, no reduction is applied, but other operations are still performed.
92+
// (In WebNN, this requires setting 'axes' to an empty array.)
93+
if (!axes_data.empty() || helper.Get("noop_with_empty_axes", 0) == 1) {
94+
options.set("axes", emscripten::val::array(GetNarrowedIntFromInt64<uint32_t>(axes_data)));
11995
}
12096

97+
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
98+
ORT_RETURN_IF(webnn_op_type.empty(), "Cannot get WebNN op type");
99+
100+
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>(
101+
std::string(webnn_op_type).c_str(), input, options);
102+
121103
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
122104
return Status::OK();
123105
}
@@ -128,11 +110,25 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer,
128110
const WebnnDeviceType /* device_type */,
129111
const logging::Logger& logger) const {
130112
const auto& input_defs = node.InputDefs();
131-
const std::string axes_name = GetTensorName(input_defs, 1);
132-
// If the optional input 'axes' is provided, it must be an initializer.
133-
if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) {
134-
LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant";
135-
return false;
113+
114+
if (TensorExists(input_defs, 1)) {
115+
std::vector<int64_t> axes_shape;
116+
if (!GetShape(*input_defs[1], axes_shape, logger)) {
117+
LOGS(logger, VERBOSE) << "Cannot get shape of input axes";
118+
return false;
119+
}
120+
121+
if (axes_shape.size() != 1) {
122+
LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be 1D";
123+
return false;
124+
}
125+
126+
const std::string axes_name = GetTensorName(input_defs, 1);
127+
// If the optional input 'axes' is provided and not empty, it must be an initializer.
128+
if (axes_shape[0] != 0 && !graph_viewer.GetConstantInitializer(axes_name)) {
129+
LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant";
130+
return false;
131+
}
136132
}
137133

138134
return true;

0 commit comments

Comments
 (0)