10 files changed: +145 −1 lines changed.
Operator visitor registry — add the new `op_exp` module to the import list:

```diff
 op_dynamic_dequantize_ops,
 op_dynamic_quantize_ops,
 op_elu,
+op_exp,
 op_floor,
 op_gelu,
 op_hardswish,
```
New file — the node visitor that serializes `aten.exp.default` nodes into the XNNPACK graph:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNExp,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    target = "aten.exp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNExp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
```
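The `@register_node_visitor` decorator registers the class by its `target` string at import time, which is why adding `op_exp` to the package imports above is all that is needed to enable it. A minimal sketch of that registry pattern — `_NODE_VISITORS` and `visitor_for` are hypothetical names standing in for the actual `node_visitor.py` internals:

```python
from typing import Dict, Type

# Hypothetical registry; the real one lives in node_visitor.py.
_NODE_VISITORS: Dict[str, Type] = {}


def register_node_visitor(visitor_cls: Type) -> Type:
    # Key the visitor class by its aten target string at import time.
    _NODE_VISITORS[visitor_cls.target] = visitor_cls
    return visitor_cls


def visitor_for(target: str) -> Type:
    # During serialization, each FX node's target is looked up here and
    # the matching visitor's define_node() is called on the node.
    return _NODE_VISITORS[target]
```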
Partitioner config registry — import `ExpConfig` (and move the commented-out `EluConfig` placeholder into sorted order), then add it to the list of enabled configs:

```diff
 ConstantPadConfig,
 DeQuantizedPerTensorConfig,
 DivConfig,
+# EluConfig,
+ExpConfig,
 FloorConfig,
 GeluConfig,
 HardswishConfig,
-# EluConfig,
 HardtanhConfig,
 LeakyReLUConfig,
 LogConfig,
```

```diff
 ClampConfig,
 DivConfig,
 # EluConfig,  # Waiting for PyTorch Pin Update
+ExpConfig,
 FloorConfig,
 GeluConfig,
 HardtanhConfig,
```
Partitioner config definition — a `GenericNodePartitionerConfig` for `exp.default`, added alongside the other unary configs (`@@ -336,6 +336,13 @@`):

```diff
         return torch.ops.aten.upsample_bilinear2d.vec


+class ExpConfig(GenericNodePartitionerConfig):
+    target_name = "exp.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
```
Supported-ops list — add the edge-dialect op so it is considered for delegation:

```diff
 exir_ops.edge.aten.sigmoid.default,
 exir_ops.edge.aten._softmax.default,
 exir_ops.edge.aten.cat.default,
+exir_ops.edge.aten.exp.default,
 exir_ops.edge.aten.elu.default,
 exir_ops.edge.aten.avg_pool2d.default,
 exir_ops.edge.aten.leaky_relu.default,
```
Runtime compiler — deserialize the exp node into the XNNPACK subgraph via `xnn_define_exp`, and register it in the node dispatch table:

```diff
   return Error::Ok;
 }

+/*
+Defines the serialized exp node into the subgraph, using the remapped ids
+to map the serialized ids to the new ids generated when defining the
+tensor value.
+*/
+Error defineExpNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNExp();
+
+  xnn_status status = xnn_define_exp(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create exp node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Defines absolute value node into subgraph using the remapped ids to map the
 serialized ids to the new ids generated when defining the tensor value
```

```diff
@@ -2082,6 +2112,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Negate)
   _DEFINE(Square)
   _DEFINE(ELU)
+  _DEFINE(Exp)
   _DEFINE(Abs)
   _DEFINE(PReLU)
   _DEFINE(Concatenate2)
```
132
132
XNNNegate: _XNNNode1x1,
133
133
XNNSquare: _XNNNode1x1,
134
134
XNNELU,
135
+ XNNExp: _XNNNode1x1,
135
136
XNNAbs: _XNNNode1x1,
136
137
XNNPReLU: _XNNNode2x1,
137
138
XNNConcatenate2: _XNNCat,
Original file line number Diff line number Diff line change @@ -128,6 +128,7 @@ union XNodeUnion {
128
128
XNNNegate: _XNNNode1x1,
129
129
XNNSquare: _XNNNode1x1,
130
130
XNNELU,
131
+ XNNExp: _XNNNode1x1,
131
132
XNNAbs: _XNNNode1x1,
132
133
XNNPReLU: _XNNNode2x1,
133
134
XNNConcatenate2: _XNNCat,
Python-side serialization schema — the matching dataclass (`@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):`):

```diff
 class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNExp(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNGelu(XNNNode1x1):
     pass
```
New file — operator test exercising export, lowering, and runtime output comparison for fp16 and fp32:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestExp(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Exp(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.exp(x)

    def run_exp_test(self, inputs):
        (
            Tester(self.Exp(), inputs)
            .export()
            .check_count({"torch.ops.aten.exp.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_exp(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_exp_test(inputs)

    def test_fp32_exp(self):
        inputs = (torch.randn(20),)
        self.run_exp_test(inputs)
```
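Beyond the graph-level checks, `run_method_and_compare_outputs` executes the serialized program through the ExecuTorch runtime and compares it against eager mode. A simplified sketch of that final comparison using the pybindings — the `_load_for_executorch_from_buffer` entry point and the tolerances here are assumptions, and the Tester handles all of this internally:

```python
import torch

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# `et_program` as produced in the earlier lowering sketch.
runtime_module = _load_for_executorch_from_buffer(et_program.buffer)

x = torch.randn(20)
(delegated_out,) = runtime_module.forward((x,))

# The delegated result should match eager torch.exp within fp32 tolerance.
assert torch.allclose(delegated_out, torch.exp(x), atol=1e-6, rtol=1e-5)
```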