Skip to content

Commit ec379da

Browse files
authored
Arm backend: Add atanh decomposition pass (#12483)
Decomposes atanh into other operators/lookup table for MI/BI case. Signed-off-by: Teo Bergkvist <[email protected]>
1 parent c2d6f3d commit ec379da

File tree

7 files changed

+153
-0
lines changed

7 files changed

+153
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
2727
from .decompose_asin_pass import DecomposeAsinPass # noqa
2828
from .decompose_atan_pass import DecomposeAtanPass # noqa
29+
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
2930
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
3031
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
3132
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
3232
DecomposeAsinPass,
33+
DecomposeAtanhPass,
3334
DecomposeAtanPass,
3435
DecomposeAvgPool2d,
3536
DecomposeBatchNormNoStatsPass,
@@ -163,6 +164,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
163164
self.add_pass(DecomposeAsinPass())
164165
self.add_pass(DecomposeSqrtPass())
165166
self.add_pass(DecomposeAtanPass())
167+
self.add_pass(DecomposeAtanhPass())
166168
self.add_pass(ConvertIntPowToMuls())
167169
self.add_pass(CastBoolToInt8Pass())
168170
self.add_pass(DecomposeSinhPass())
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
9+
10+
edge_atanh = exir_ops.edge.aten.atanh.default # MI case
11+
12+
13+
def _get_atanh_ops(op):
14+
"""Return the primitive ops required.."""
15+
if op is not edge_atanh:
16+
raise RuntimeError(f"Can't decompose atanh for op {op}")
17+
return (
18+
exir_ops.edge.aten.mul.Tensor,
19+
exir_ops.edge.aten.mul.Scalar,
20+
exir_ops.edge.aten.add.Scalar,
21+
exir_ops.edge.aten.reciprocal.default,
22+
exir_ops.edge.aten.log.default,
23+
exir_ops.edge.aten.neg.default,
24+
)
25+
26+
27+
class DecomposeAtanhPass(ArmPass):
28+
"""
29+
Decomposes the atanh operator into primitive ops.
30+
atanh(x) = 0.5 * log((1 + x) / (1 - x))
31+
"""
32+
33+
def call_operator(self, op, args, kwargs, meta):
34+
if op is not edge_atanh:
35+
return super().call_operator(op, args, kwargs, meta, updated=False)
36+
37+
ops = _get_atanh_ops(op)
38+
(
39+
op_mul_tensor,
40+
op_mul_scalar,
41+
op_add_scalar,
42+
op_reciprocal,
43+
op_log,
44+
op_neg,
45+
) = ops
46+
47+
x = args[0]
48+
49+
nom = super().call_operator(op_add_scalar, (x, 1.0), {}, meta, updated=True)
50+
51+
neg_x = super().call_operator(op_neg, (x,), {}, meta, updated=True)
52+
denom = super().call_operator(
53+
op_add_scalar, (neg_x, 1.0), {}, meta, updated=True
54+
)
55+
recip = super().call_operator(op_reciprocal, (denom,), {}, meta, updated=True)
56+
57+
log_input = super().call_operator(
58+
op_mul_tensor, (nom, recip), {}, meta, updated=True
59+
)
60+
log = super().call_operator(op_log, (log_input,), {}, meta, updated=True)
61+
62+
return super().call_operator(op_mul_scalar, (log, 0.5), {}, meta, updated=True)

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class TableOps:
5252
exir_ops.edge.aten.sin.default: torch.sin,
5353
exir_ops.edge.aten.tanh.default: torch.tanh,
5454
exir_ops.edge.aten.atan.default: torch.atan,
55+
exir_ops.edge.aten.atanh.default: torch.atanh,
5556
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5657
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5758
exir_ops.edge.aten.sinh.default: torch.sinh,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def is_node_supported(
252252
exir_ops.edge.aten._adaptive_avg_pool2d.default,
253253
exir_ops.edge.aten.sign.default,
254254
exir_ops.edge.aten.asin.default,
255+
exir_ops.edge.aten.atanh.default,
255256
]
256257

257258
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def _match_pattern(
218218
torch.ops.aten.acosh.default,
219219
torch.ops.aten.sign.default,
220220
torch.ops.aten.asin.default,
221+
torch.ops.aten.atanh.default,
221222
]
222223

223224
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_atanh.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
aten_op = "torch.ops.aten.atanh.default"
19+
exir_op = "executorch_exir_dialects_edge__ops_aten__atanh_default"
20+
21+
22+
input_t1 = Tuple[torch.Tensor]
23+
24+
25+
test_data_suite = {
26+
"zeros": torch.zeros(1, 10, 10, 10),
27+
"zeros_alt_shape": torch.zeros(1, 10, 3, 5),
28+
"ones": torch.ones(10, 10, 10),
29+
"rand": torch.rand(10, 10) - 0.5,
30+
"rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
31+
"ramp": torch.arange(-1, 1, 0.2),
32+
"near_bounds": torch.tensor([-0.999999, -0.999, -0.9, 0.9, 0.999, 0.999999]),
33+
"on_bounds": torch.tensor([-1.0, 1.0]),
34+
}
35+
36+
37+
class Atanh(torch.nn.Module):
38+
def forward(self, x: torch.Tensor):
39+
return torch.atanh(x)
40+
41+
42+
@common.parametrize("test_data", test_data_suite)
43+
def test_atanh_tosa_MI(test_data: Tuple):
44+
pipeline = TosaPipelineMI[input_t1](
45+
Atanh(),
46+
(test_data,),
47+
aten_op=aten_op,
48+
exir_op=exir_op,
49+
)
50+
pipeline.run()
51+
52+
53+
@common.parametrize("test_data", test_data_suite)
54+
def test_atanh_tosa_BI(test_data: Tuple):
55+
pipeline = TosaPipelineBI[input_t1](
56+
Atanh(),
57+
(test_data,),
58+
aten_op=aten_op,
59+
exir_op=exir_op,
60+
)
61+
pipeline.run()
62+
63+
64+
@common.XfailIfNoCorstone300
65+
@common.parametrize("test_data", test_data_suite)
66+
def test_atanh_u55_BI(test_data: Tuple):
67+
pipeline = EthosU55PipelineBI[input_t1](
68+
Atanh(),
69+
(test_data,),
70+
aten_ops=aten_op,
71+
exir_ops=exir_op,
72+
)
73+
pipeline.run()
74+
75+
76+
@common.XfailIfNoCorstone320
77+
@common.parametrize("test_data", test_data_suite)
78+
def test_atanh_u85_BI(test_data: Tuple):
79+
pipeline = EthosU85PipelineBI[input_t1](
80+
Atanh(),
81+
(test_data,),
82+
aten_ops=aten_op,
83+
exir_ops=exir_op,
84+
)
85+
pipeline.run()

0 commit comments

Comments
 (0)