|
9 | 9 | #include "core/common/inlined_containers.h"
|
10 | 10 | #include <core/graph/basic_types.h>
|
11 | 11 | #include "core/optimizer/initializer.h"
|
12 |
| -#include "core/providers/common.h" |
13 | 12 | #include "core/providers/shared/utils/utils.h"
|
| 13 | +#include "map_info.h" |
14 | 14 |
|
15 | 15 | #include <emscripten.h>
|
16 | 16 | #include <emscripten/val.h>
|
@@ -201,183 +201,27 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
|
201 | 201 | const emscripten::val& wnn_limits,
|
202 | 202 | const logging::Logger& logger);
|
203 | 203 |
|
204 |
| -// Some ONNX ops are supported by decomposed WebNN ops. |
205 |
| -const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = { |
206 |
| - {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, |
207 |
| - {"GroupQueryAttention", |
208 |
| - {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", |
209 |
| - "softmax", "transpose", "where"}}, |
210 |
| - {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, |
211 |
| - {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, |
212 |
| - {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, |
213 |
| - {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, |
214 |
| - {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, |
215 |
| - {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, |
216 |
| - {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, |
217 |
| -}; |
218 |
| -// ONNX op type to WebNN op type mapping. |
219 |
| -const std::map<std::string_view, std::string_view> op_map = { |
220 |
| - {"Abs", "abs"}, |
221 |
| - {"Add", "add"}, |
222 |
| - {"And", "logicalAnd"}, |
223 |
| - {"ArgMax", "argMax"}, |
224 |
| - {"ArgMin", "argMin"}, |
225 |
| - {"AveragePool", "averagePool2d"}, |
226 |
| - {"BatchNormalization", "batchNormalization"}, |
227 |
| - {"Cast", "cast"}, |
228 |
| - {"Ceil", "ceil"}, |
229 |
| - {"Clip", "clamp"}, |
230 |
| - {"Concat", "concat"}, |
231 |
| - {"Conv", "conv2d"}, |
232 |
| - {"ConvTranspose", "convTranspose2d"}, |
233 |
| - {"Cos", "cos"}, |
234 |
| - {"CumSum", "cumulativeSum"}, |
235 |
| - {"Div", "div"}, |
236 |
| - {"DequantizeLinear", "dequantizeLinear"}, |
237 |
| - {"Dropout", "identity"}, |
238 |
| - {"DynamicQuantizeLinear", "dynamicQuantizeLinear"}, |
239 |
| - {"Einsum", "matmul"}, |
240 |
| - {"Elu", "elu"}, |
241 |
| - {"Equal", "equal"}, |
242 |
| - {"Erf", "erf"}, |
243 |
| - {"Exp", "exp"}, |
244 |
| - {"Expand", "expand"}, |
245 |
| - {"Flatten", "reshape"}, |
246 |
| - {"Floor", "floor"}, |
247 |
| - {"Gather", "gather"}, |
248 |
| - {"GatherElements", "gatherElements"}, |
249 |
| - {"GatherND", "gatherND"}, |
250 |
| - {"Gelu", "gelu"}, |
251 |
| - {"Gemm", "gemm"}, |
252 |
| - {"GlobalAveragePool", "averagePool2d"}, |
253 |
| - {"GlobalMaxPool", "maxPool2d"}, |
254 |
| - {"GlobalLpPool", "l2Pool2d"}, |
255 |
| - {"Greater", "greater"}, |
256 |
| - {"GreaterOrEqual", "greaterOrEqual"}, |
257 |
| - {"GRU", "gru"}, |
258 |
| - {"HardSigmoid", "hardSigmoid"}, |
259 |
| - {"HardSwish", "hardSwish"}, |
260 |
| - {"Identity", "identity"}, |
261 |
| - {"InstanceNormalization", "instanceNormalization"}, |
262 |
| - {"LayerNormalization", "layerNormalization"}, |
263 |
| - {"LeakyRelu", "leakyRelu"}, |
264 |
| - {"Less", "lesser"}, |
265 |
| - {"LessOrEqual", "lesserOrEqual"}, |
266 |
| - {"Log", "log"}, |
267 |
| - {"LpPool", "l2Pool2d"}, |
268 |
| - {"LSTM", "lstm"}, |
269 |
| - {"MatMul", "matmul"}, |
270 |
| - {"Max", "max"}, |
271 |
| - {"MaxPool", "maxPool2d"}, |
272 |
| - {"Min", "min"}, |
273 |
| - {"Mul", "mul"}, |
274 |
| - {"Neg", "neg"}, |
275 |
| - {"Not", "logicalNot"}, |
276 |
| - {"Or", "logicalOr"}, |
277 |
| - {"Pad", "pad"}, |
278 |
| - {"Pow", "pow"}, |
279 |
| - {"PRelu", "prelu"}, |
280 |
| - {"QuantizeLinear", "quantizeLinear"}, |
281 |
| - {"Reciprocal", "reciprocal"}, |
282 |
| - {"ReduceL1", "reduceL1"}, |
283 |
| - {"ReduceL2", "reduceL2"}, |
284 |
| - {"ReduceLogSum", "reduceLogSum"}, |
285 |
| - {"ReduceLogSumExp", "reduceLogSumExp"}, |
286 |
| - {"ReduceMax", "reduceMax"}, |
287 |
| - {"ReduceMean", "reduceMean"}, |
288 |
| - {"ReduceMin", "reduceMin"}, |
289 |
| - {"ReduceProd", "reduceProduct"}, |
290 |
| - {"ReduceSum", "reduceSum"}, |
291 |
| - {"ReduceSumSquare", "reduceSumSquare"}, |
292 |
| - {"Relu", "relu"}, |
293 |
| - {"Reshape", "reshape"}, |
294 |
| - {"Resize", "resample2d"}, |
295 |
| - {"ScatterElements", "scatterElements"}, |
296 |
| - {"ScatterND", "scatterND"}, |
297 |
| - {"Shape", "slice"}, |
298 |
| - {"Sigmoid", "sigmoid"}, |
299 |
| - {"Sign", "sign"}, |
300 |
| - {"Softplus", "softplus"}, |
301 |
| - {"Softsign", "softsign"}, |
302 |
| - {"Sin", "sin"}, |
303 |
| - {"Slice", "slice"}, |
304 |
| - {"Softmax", "softmax"}, |
305 |
| - {"Split", "split"}, |
306 |
| - {"Sqrt", "sqrt"}, |
307 |
| - {"Squeeze", "reshape"}, |
308 |
| - {"Sub", "sub"}, |
309 |
| - {"Tan", "tan"}, |
310 |
| - {"Tanh", "tanh"}, |
311 |
| - {"Tile", "tile"}, |
312 |
| - {"Transpose", "transpose"}, |
313 |
| - {"Trilu", "triangular"}, |
314 |
| - {"Unsqueeze", "reshape"}, |
315 |
| - {"Where", "where"}, |
316 |
| - {"Xor", "logicalXor"}, |
317 |
| -}; |
318 |
| - |
319 |
| -// WebNN op name to its first input name mapping; only names that differ from "input" are recorded. |
320 |
| -// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits. |
321 |
| -const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = { |
322 |
| - {"add", "a"}, |
323 |
| - {"concat", "inputs"}, |
324 |
| - {"div", "a"}, |
325 |
| - {"equal", "a"}, |
326 |
| - {"gemm", "a"}, |
327 |
| - {"greater", "a"}, |
328 |
| - {"greaterOrEqual", "a"}, |
329 |
| - {"lesser", "a"}, |
330 |
| - {"lesserOrEqual", "a"}, |
331 |
| - {"logicalAnd", "a"}, |
332 |
| - {"logicalNot", "a"}, |
333 |
| - {"logicalOr", "a"}, |
334 |
| - {"logicalXor", "a"}, |
335 |
| - {"matmul", "a"}, |
336 |
| - {"max", "a"}, |
337 |
| - {"min", "a"}, |
338 |
| - {"mul", "a"}, |
339 |
| - {"pow", "a"}, |
340 |
| - {"sub", "a"}, |
341 |
| - {"where", "condition"}, |
342 |
| -}; |
343 |
| - |
344 | 204 | // Retrieve the first input name of a WebNN op used for validating supported input data types.
|
345 | 205 | // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
|
346 |
| -// Special names other than 'input' are recorded in the webnn_op_first_input_name_map. |
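| 206 | +// All WebNN op inputs are recorded in op_inputs_map. NOTE(review): GetWebNNOpType looks up op_inputs_map by ONNX op type, whereas this function looks it up with a WebNN op type; if the map is keyed by ONNX op type, the lookup here would miss and fall back to "input" - confirm the key space against callers. |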
347 | 207 | inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
|
348 |
| - auto it = webnn_op_first_input_name_map.find(webnn_op_type); |
349 |
| - return (it != webnn_op_first_input_name_map.end()) ? it->second : "input"; |
| 208 | + auto it = op_inputs_map.find(webnn_op_type); |
| 209 | + if (it != op_inputs_map.end()) { |
| 210 | + for (const auto& input : it->second.inputs) { |
| 211 | + if (input.index == 0) { |
| 212 | + return input.name; |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | + return "input"; |
350 | 217 | }
|
351 | 218 |
|
352 | 219 | inline std::string_view GetWebNNOpType(const std::string_view op_type) {
|
353 |
| - auto it = op_map.find(op_type); |
354 |
| - // Return an empty string if the op_type is not listed in the op_map. |
355 |
| - return (it != op_map.end()) ? it->second : ""; |
| 220 | + auto it = op_inputs_map.find(op_type); |
| 221 | + // Return an empty string if the op_type is not listed in the op_inputs_map. |
| 222 | + return (it != op_inputs_map.end()) ? it->second.opType : ""; |
356 | 223 | }
|
357 | 224 |
|
358 |
| -const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = { |
359 |
| - {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, |
360 |
| - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, |
361 |
| - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, |
362 |
| - {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, |
363 |
| - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, |
364 |
| - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, |
365 |
| - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, |
366 |
| - {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, |
367 |
| - {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, |
368 |
| - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, |
369 |
| - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, |
370 |
| -}; |
371 |
| - |
372 |
| -// This array contains the input/output data types of a WebNN graph that are allowed to fall back to int32. |
373 |
| -constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 5> supported_fallback_integer_data_types = { |
374 |
| - ONNX_NAMESPACE::TensorProto_DataType_BOOL, |
375 |
| - ONNX_NAMESPACE::TensorProto_DataType_INT8, |
376 |
| - ONNX_NAMESPACE::TensorProto_DataType_UINT8, |
377 |
| - ONNX_NAMESPACE::TensorProto_DataType_UINT32, |
378 |
| - ONNX_NAMESPACE::TensorProto_DataType_INT64, |
379 |
| -}; |
380 |
| - |
381 | 225 | bool AreDataTypesSame(const std::string_view op_type,
|
382 | 226 | gsl::span<const int32_t> input_types,
|
383 | 227 | const logging::Logger& logger);
|
|
0 commit comments