
Commit a54870e

Support exp op in XNNPACK backend (#11373)
### Summary
Support exp in XNNPACK backend.

### Test plan
Wrote test cases to check that the appropriate XNNPACK exp node is called.
1 parent d135ba4 commit a54870e
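For context, the following is a minimal sketch (not part of this commit) of how a model that calls torch.exp can be exported and lowered to the XNNPACK backend once exp is supported. The ExpModel module and the input shape are illustrative; the export and lowering calls assume the standard ExecuTorch APIs (torch.export.export, to_edge_transform_and_lower, XnnpackPartitioner).

# Minimal sketch (not from this commit): lower a model that uses torch.exp
# to the XNNPACK backend. ExpModel and the input shape are illustrative.
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class ExpModel(torch.nn.Module):
    def forward(self, x):
        return torch.exp(x)


example_inputs = (torch.randn(20),)
exported = torch.export.export(ExpModel().eval(), example_inputs)

# With exp supported, the partitioner can delegate aten.exp nodes to XNNPACK,
# so the call ends up inside an executorch_call_delegate region.
executorch_program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()

From there, the program buffer can be written out as a .pte file in the usual way.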

File tree

10 files changed (+145 lines, -1 line)


backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     op_dynamic_dequantize_ops,
     op_dynamic_quantize_ops,
     op_elu,
+    op_exp,
     op_floor,
     op_gelu,
     op_hardswish,

backends/xnnpack/operators/op_exp.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNExp,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    target = "aten.exp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNExp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -25,10 +25,11 @@
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
+    # EluConfig,
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardswishConfig,
-    # EluConfig,
     HardtanhConfig,
     LeakyReLUConfig,
     LogConfig,

@@ -79,6 +80,7 @@
     ClampConfig,
     DivConfig,
     # EluConfig,  # Waiting for PyTorch Pin Update
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardtanhConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
@@ -336,6 +336,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec


+class ExpConfig(GenericNodePartitionerConfig):
+    target_name = "exp.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
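The ExpConfig added above follows the generic unary-op pattern used throughout this file: a target name plus the precision types the backend supports. As a hedged illustration of how the same pattern would extend to another op, the sketch below registers a hypothetical expm1 target; it is not part of this commit, the target is illustrative only (XNNPACK would also need a matching kernel, serialization, and runtime support), and the import paths are assumed from this package's layout.

# Hypothetical sketch, not part of this commit: the same generic config pattern
# applied to another unary op. "expm1.default" is an illustrative target only.
from typing import List

from executorch.backends.xnnpack.partition.config.generic_node_configs import (
    GenericNodePartitionerConfig,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)


class Expm1Config(GenericNodePartitionerConfig):
    target_name = "expm1.default"  # hypothetical; needs backend support as well

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]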

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@
     exir_ops.edge.aten.sigmoid.default,
     exir_ops.edge.aten._softmax.default,
     exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.exp.default,
     exir_ops.edge.aten.elu.default,
     exir_ops.edge.aten.avg_pool2d.default,
     exir_ops.edge.aten.leaky_relu.default,

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions
@@ -1723,6 +1723,36 @@ Error defineELUNode(
   return Error::Ok;
 }

+/*
+Define serialized exp node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineExpNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNExp();
+
+  xnn_status status = xnn_define_exp(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create exp node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Defines absolute value node into subgraph using the remapped ids to map the
 serialized ids to the new ids generated when defining the tensor value

@@ -2082,6 +2112,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Negate)
     _DEFINE(Square)
     _DEFINE(ELU)
+    _DEFINE(Exp)
     _DEFINE(Abs)
     _DEFINE(PReLU)
     _DEFINE(Concatenate2)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ union XNodeUnion {
   XNNNegate: _XNNNode1x1,
   XNNSquare: _XNNNode1x1,
   XNNELU,
+  XNNExp: _XNNNode1x1,
   XNNAbs: _XNNNode1x1,
   XNNPReLU: _XNNNode2x1,
   XNNConcatenate2: _XNNCat,

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ union XNodeUnion {
   XNNNegate: _XNNNode1x1,
   XNNSquare: _XNNNode1x1,
   XNNELU,
+  XNNExp: _XNNNode1x1,
   XNNAbs: _XNNNode1x1,
   XNNPReLU: _XNNNode2x1,
   XNNConcatenate2: _XNNCat,

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 5 additions & 0 deletions
@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNExp(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNGelu(XNNNode1x1):
     pass

backends/xnnpack/test/ops/test_exp.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestExp(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Exp(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.exp(x)

    def run_exp_test(self, inputs):
        (
            Tester(self.Exp(), inputs)
            .export()
            .check_count({"torch.ops.aten.exp.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_exp(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_exp_test(inputs)

    def test_fp32_exp(self):
        inputs = (torch.randn(20),)
        self.run_exp_test(inputs)
