Skip to content

Commit 2798b6c

Browse files
authored
Arm backend: Add cosh decomposition pass and test (#13181)
Add decomposition and tests for cosh. Signed-off-by: Emma Kujala <[email protected]>
1 parent be221c6 commit 2798b6c

File tree

7 files changed

+161
-0
lines changed

7 files changed

+161
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
3232
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
3333
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
34+
from .decompose_cosh_pass import DecomposeCoshPass # noqa
3435
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3536
from .decompose_div_pass import DecomposeDivPass # noqa
3637
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DecomposeAtanPass,
3737
DecomposeAvgPool2d,
3838
DecomposeBatchNormNoStatsPass,
39+
DecomposeCoshPass,
3940
DecomposeCosineSimilarityPass,
4041
DecomposeDivPass,
4142
DecomposeEmbeddingPass,
@@ -167,6 +168,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
167168
self.add_pass(DecomposeAcoshPass())
168169
self.add_pass(DecomposeAsinPass())
169170
self.add_pass(DecomposeAsinhPass())
171+
self.add_pass(DecomposeCoshPass())
170172
self.add_pass(DecomposeSqrtPass())
171173
self.add_pass(DecomposeAtanPass())
172174
self.add_pass(DecomposeAtanhPass())
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
9+
# For MI case
10+
edge_cosh = exir_ops.edge.aten.cosh.default
11+
12+
13+
class DecomposeCoshPass(ArmPass):
14+
"""
15+
This pass replaces the cosh operator with a sequence of TOSA-equivalent operations that
16+
compute the hyperbolic cosine using the formula:
17+
18+
cosh(x) = 0.5 * (e^x + e^(-x))
19+
20+
"""
21+
22+
def call_operator(self, op, args, kwargs, meta, updated=False):
23+
if op is not edge_cosh:
24+
return super().call_operator(op, args, kwargs, meta, updated)
25+
26+
x = args
27+
28+
exp_op, mul_op, neg_op, add_op = (
29+
exir_ops.edge.aten.exp.default,
30+
exir_ops.edge.aten.mul.Scalar,
31+
exir_ops.edge.aten.neg.default,
32+
exir_ops.edge.aten.add.Tensor,
33+
)
34+
35+
# exp1 = e^x
36+
exp1 = super().call_operator(exp_op, x, {}, meta, updated=True)
37+
38+
# exp2 = e^(⁻x)
39+
neg_x = super().call_operator(neg_op, x, {}, meta, updated=True)
40+
exp2 = super().call_operator(exp_op, (neg_x,), {}, meta, updated=True)
41+
42+
# numer = exp1 + exp2
43+
numer = super().call_operator(add_op, (exp1, exp2), {}, meta, updated=True)
44+
45+
# out = 0.5 * numer
46+
out = super().call_operator(mul_op, (numer, 0.5), {}, meta, updated=True)
47+
48+
return out

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class TableOps:
5959
exir_ops.edge.aten.acosh.default: torch.acosh,
6060
exir_ops.edge.aten.asin.default: torch.asin,
6161
exir_ops.edge.aten.asinh.default: torch.asinh,
62+
exir_ops.edge.aten.cosh.default: torch.cosh,
6263
}
6364

6465
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def is_node_supported(
257257
exir_ops.edge.aten.addmm.default,
258258
exir_ops.edge.aten.masked_fill.Scalar,
259259
exir_ops.edge.aten.asinh.default,
260+
exir_ops.edge.aten.cosh.default,
260261
]
261262

262263
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def _match_pattern(
287287
torch.ops.aten.asin.default,
288288
torch.ops.aten.atanh.default,
289289
torch.ops.aten.asinh.default,
290+
torch.ops.aten.cosh.default,
290291
]
291292

292293
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_cosh.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Tuple
6+
7+
import torch
8+
from executorch.backends.arm.test import common
9+
from executorch.backends.arm.test.tester.test_pipeline import (
10+
EthosU55PipelineINT,
11+
EthosU85PipelineINT,
12+
TosaPipelineFP,
13+
TosaPipelineINT,
14+
VgfPipeline,
15+
)
16+
17+
aten_op = "torch.ops.aten.cosh.default"
18+
exir_op = "executorch_exir_dialects_edge__ops_aten__cosh_default"
19+
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
22+
test_data_suite = {
23+
# (test_name, test_data)
24+
"zeros": torch.zeros(10, 10, 10),
25+
"zeros_4D": torch.zeros(1, 10, 32, 7),
26+
"zeros_alt_shape": torch.zeros(10, 3, 5),
27+
"ones": torch.ones(15, 10, 7),
28+
"ones_4D": torch.ones(1, 3, 32, 16),
29+
"rand": torch.rand(10, 10) - 0.5,
30+
"rand_alt_shape": torch.rand(10, 3, 5) - 0.5,
31+
"rand_4D": torch.rand(1, 6, 5, 7) - 0.5,
32+
"randn_pos": torch.randn(10) + 10,
33+
"randn_neg": torch.randn(10) - 10,
34+
"ramp": torch.arange(-16, 16, 0.2),
35+
"large": 100 * torch.ones(1, 1),
36+
"small": 0.000001 * torch.ones(1, 1),
37+
"small_rand": torch.rand(100) * 0.01,
38+
"biggest": torch.tensor([700.0, 710.0, 750.0]),
39+
}
40+
41+
42+
class Cosh(torch.nn.Module):
43+
def forward(self, x: torch.Tensor):
44+
return torch.cosh(x)
45+
46+
47+
@common.parametrize("test_data", test_data_suite)
48+
def test_cosh_tosa_FP(test_data: Tuple):
49+
pipeline = TosaPipelineFP[input_t1](
50+
Cosh(),
51+
(test_data,),
52+
aten_op,
53+
exir_op,
54+
)
55+
pipeline.run()
56+
57+
58+
@common.parametrize("test_data", test_data_suite)
59+
def test_cosh_tosa_INT(test_data: Tuple):
60+
pipeline = TosaPipelineINT[input_t1](
61+
Cosh(), (test_data,), aten_op=aten_op, exir_op=exir_op
62+
)
63+
pipeline.run()
64+
65+
66+
@common.XfailIfNoCorstone300
67+
@common.parametrize("test_data", test_data_suite)
68+
def test_cosh_u55_INT(test_data: Tuple):
69+
pipeline = EthosU55PipelineINT[input_t1](
70+
Cosh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op
71+
)
72+
pipeline.run()
73+
74+
75+
@common.XfailIfNoCorstone320
76+
@common.parametrize("test_data", test_data_suite)
77+
def test_cosh_u85_INT(test_data: Tuple):
78+
pipeline = EthosU85PipelineINT[input_t1](
79+
Cosh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op
80+
)
81+
pipeline.run()
82+
83+
84+
@common.parametrize("test_data", test_data_suite)
85+
@common.SkipIfNoModelConverter
86+
def test_cosh_vgf_FP(test_data: Tuple):
87+
pipeline = VgfPipeline[input_t1](
88+
Cosh(),
89+
(test_data,),
90+
[],
91+
[],
92+
tosa_version="TOSA-1.0+FP",
93+
)
94+
pipeline.run()
95+
96+
97+
@common.parametrize("test_data", test_data_suite)
98+
@common.SkipIfNoModelConverter
99+
def test_cosh_vgf_INT(test_data: Tuple):
100+
pipeline = VgfPipeline[input_t1](
101+
Cosh(),
102+
(test_data,),
103+
[],
104+
[],
105+
tosa_version="TOSA-1.0+INT",
106+
)
107+
pipeline.run()

0 commit comments

Comments
 (0)