Skip to content

Commit 5618199

Browse files
NingW101 and Honry authored
[WebNN] Refactor op mappings and add input name mapping between ONNX and WebNN (#24830)
### Description

Add `map_info.h` to centralize the operation types and inputs mapping between ONNX and WebNN.

### Motivation and Context

To simplify the maintenance of operation types and inputs. The mapping of ONNX input names and WebNN input names will be used in the future to check the `rankRange`. @Honry, @fdwr, @guschmue, PTAL, thanks!

---------

Co-authored-by: Wanming Lin <[email protected]>
1 parent 625289c commit 5618199

File tree

2 files changed

+219
-170
lines changed

2 files changed

+219
-170
lines changed

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 14 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#include "core/common/inlined_containers.h"
1010
#include <core/graph/basic_types.h>
1111
#include "core/optimizer/initializer.h"
12-
#include "core/providers/common.h"
1312
#include "core/providers/shared/utils/utils.h"
13+
#include "map_info.h"
1414

1515
#include <emscripten.h>
1616
#include <emscripten/val.h>
@@ -201,183 +201,27 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
201201
const emscripten::val& wnn_limits,
202202
const logging::Logger& logger);
203203

204-
// Some ONNX ops are supported by decomposed WebNN ops.
205-
const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
206-
{"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}},
207-
{"GroupQueryAttention",
208-
{"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND",
209-
"softmax", "transpose", "where"}},
210-
{"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
211-
{"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}},
212-
{"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}},
213-
{"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}},
214-
{"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}},
215-
{"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
216-
{"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
217-
};
218-
// ONNX op type to WebNN op type mapping.
219-
const std::map<std::string_view, std::string_view> op_map = {
220-
{"Abs", "abs"},
221-
{"Add", "add"},
222-
{"And", "logicalAnd"},
223-
{"ArgMax", "argMax"},
224-
{"ArgMin", "argMin"},
225-
{"AveragePool", "averagePool2d"},
226-
{"BatchNormalization", "batchNormalization"},
227-
{"Cast", "cast"},
228-
{"Ceil", "ceil"},
229-
{"Clip", "clamp"},
230-
{"Concat", "concat"},
231-
{"Conv", "conv2d"},
232-
{"ConvTranspose", "convTranspose2d"},
233-
{"Cos", "cos"},
234-
{"CumSum", "cumulativeSum"},
235-
{"Div", "div"},
236-
{"DequantizeLinear", "dequantizeLinear"},
237-
{"Dropout", "identity"},
238-
{"DynamicQuantizeLinear", "dynamicQuantizeLinear"},
239-
{"Einsum", "matmul"},
240-
{"Elu", "elu"},
241-
{"Equal", "equal"},
242-
{"Erf", "erf"},
243-
{"Exp", "exp"},
244-
{"Expand", "expand"},
245-
{"Flatten", "reshape"},
246-
{"Floor", "floor"},
247-
{"Gather", "gather"},
248-
{"GatherElements", "gatherElements"},
249-
{"GatherND", "gatherND"},
250-
{"Gelu", "gelu"},
251-
{"Gemm", "gemm"},
252-
{"GlobalAveragePool", "averagePool2d"},
253-
{"GlobalMaxPool", "maxPool2d"},
254-
{"GlobalLpPool", "l2Pool2d"},
255-
{"Greater", "greater"},
256-
{"GreaterOrEqual", "greaterOrEqual"},
257-
{"GRU", "gru"},
258-
{"HardSigmoid", "hardSigmoid"},
259-
{"HardSwish", "hardSwish"},
260-
{"Identity", "identity"},
261-
{"InstanceNormalization", "instanceNormalization"},
262-
{"LayerNormalization", "layerNormalization"},
263-
{"LeakyRelu", "leakyRelu"},
264-
{"Less", "lesser"},
265-
{"LessOrEqual", "lesserOrEqual"},
266-
{"Log", "log"},
267-
{"LpPool", "l2Pool2d"},
268-
{"LSTM", "lstm"},
269-
{"MatMul", "matmul"},
270-
{"Max", "max"},
271-
{"MaxPool", "maxPool2d"},
272-
{"Min", "min"},
273-
{"Mul", "mul"},
274-
{"Neg", "neg"},
275-
{"Not", "logicalNot"},
276-
{"Or", "logicalOr"},
277-
{"Pad", "pad"},
278-
{"Pow", "pow"},
279-
{"PRelu", "prelu"},
280-
{"QuantizeLinear", "quantizeLinear"},
281-
{"Reciprocal", "reciprocal"},
282-
{"ReduceL1", "reduceL1"},
283-
{"ReduceL2", "reduceL2"},
284-
{"ReduceLogSum", "reduceLogSum"},
285-
{"ReduceLogSumExp", "reduceLogSumExp"},
286-
{"ReduceMax", "reduceMax"},
287-
{"ReduceMean", "reduceMean"},
288-
{"ReduceMin", "reduceMin"},
289-
{"ReduceProd", "reduceProduct"},
290-
{"ReduceSum", "reduceSum"},
291-
{"ReduceSumSquare", "reduceSumSquare"},
292-
{"Relu", "relu"},
293-
{"Reshape", "reshape"},
294-
{"Resize", "resample2d"},
295-
{"ScatterElements", "scatterElements"},
296-
{"ScatterND", "scatterND"},
297-
{"Shape", "slice"},
298-
{"Sigmoid", "sigmoid"},
299-
{"Sign", "sign"},
300-
{"Softplus", "softplus"},
301-
{"Softsign", "softsign"},
302-
{"Sin", "sin"},
303-
{"Slice", "slice"},
304-
{"Softmax", "softmax"},
305-
{"Split", "split"},
306-
{"Sqrt", "sqrt"},
307-
{"Squeeze", "reshape"},
308-
{"Sub", "sub"},
309-
{"Tan", "tan"},
310-
{"Tanh", "tanh"},
311-
{"Tile", "tile"},
312-
{"Transpose", "transpose"},
313-
{"Trilu", "triangular"},
314-
{"Unsqueeze", "reshape"},
315-
{"Where", "where"},
316-
{"Xor", "logicalXor"},
317-
};
318-
319-
// WebNN op name to its first input name mapping, only record the name that is different from "input".
320-
// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits.
321-
const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
322-
{"add", "a"},
323-
{"concat", "inputs"},
324-
{"div", "a"},
325-
{"equal", "a"},
326-
{"gemm", "a"},
327-
{"greater", "a"},
328-
{"greaterOrEqual", "a"},
329-
{"lesser", "a"},
330-
{"lesserOrEqual", "a"},
331-
{"logicalAnd", "a"},
332-
{"logicalNot", "a"},
333-
{"logicalOr", "a"},
334-
{"logicalXor", "a"},
335-
{"matmul", "a"},
336-
{"max", "a"},
337-
{"min", "a"},
338-
{"mul", "a"},
339-
{"pow", "a"},
340-
{"sub", "a"},
341-
{"where", "condition"},
342-
};
343-
344204
// Retrieve the first input name of a WebNN op used for validating supported input data types.
345205
// WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
346-
// Special names other than 'input' are recorded in the webnn_op_first_input_name_map.
206+
// All WebNN op inputs are recorded in op_inputs_map.
347207
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
348-
auto it = webnn_op_first_input_name_map.find(webnn_op_type);
349-
return (it != webnn_op_first_input_name_map.end()) ? it->second : "input";
208+
auto it = op_inputs_map.find(webnn_op_type);
209+
if (it != op_inputs_map.end()) {
210+
for (const auto& input : it->second.inputs) {
211+
if (input.index == 0) {
212+
return input.name;
213+
}
214+
}
215+
}
216+
return "input";
350217
}
351218

352219
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
353-
auto it = op_map.find(op_type);
354-
// Return an empty string if the op_type is not listed in the op_map.
355-
return (it != op_map.end()) ? it->second : "";
220+
auto it = op_inputs_map.find(op_type);
221+
// Return an empty string if the op_type is not listed in the op_inputs_map.
222+
return (it != op_inputs_map.end()) ? it->second.opType : "";
356223
}
357224

358-
const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
359-
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
360-
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
361-
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
362-
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
363-
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
364-
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
365-
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
366-
{ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
367-
{ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
368-
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
369-
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
370-
};
371-
372-
// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32.
373-
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 5> supported_fallback_integer_data_types = {
374-
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
375-
ONNX_NAMESPACE::TensorProto_DataType_INT8,
376-
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
377-
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
378-
ONNX_NAMESPACE::TensorProto_DataType_INT64,
379-
};
380-
381225
bool AreDataTypesSame(const std::string_view op_type,
382226
gsl::span<const int32_t> input_types,
383227
const logging::Logger& logger);

0 commit comments

Comments (0)