Skip to content

Commit d836239

Browse files
Merge branch 'main' into add-cosh-decomposition
2 parents 0cf7b75 + b114f9c commit d836239

16 files changed

+267
-172
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[submodule "backends/arm/third-party/ethos-u-core-driver"]
22
path = backends/arm/third-party/ethos-u-core-driver
33
url = https://git.gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-core-driver.git
4+
[submodule "backends/arm/third-party/serialization_lib"]
5+
path = backends/arm/third-party/serialization_lib
6+
url = https://git.gitlab.arm.com/tosa/tosa-serialization.git
47
[submodule "backends/vulkan/third-party/Vulkan-Headers"]
58
path = backends/vulkan/third-party/Vulkan-Headers
69
url = https://github.com/KhronosGroup/Vulkan-Headers

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy
77

88
import torch
9-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass
1212

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
get_param_tensor,
1616
is_param_node,
1717
)
18-
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1918

20-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
19+
from executorch.backends.arm._passes.quant_args import QuantArgs
20+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2121

2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from executorch.exir.dialects.edge._ops import EdgeOpOverload

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
910
from executorch.backends.arm.constants import Q_OPS
10-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass, PassResult
1313
from torch.fx import Node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.quant_args import QuantArgs
1213
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch import Tensor
1616
from torch.fx import GraphModule, Node

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
13+
from executorch.backends.arm._passes.quant_args import QuantArgs
1414
from executorch.exir import ExportedProgram
1515

1616
from executorch.exir.dialects._ops import ops as exir_ops

backends/arm/_passes/quant_args.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Any, cast, NamedTuple
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
exir_ops = cast(Any, exir_ops)
11+
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
12+
from torch import Tensor
13+
14+
15+
class QuantArgs(NamedTuple):
16+
scale: list[float] | float
17+
zp: list[int] | int
18+
qmin: int
19+
qmax: int
20+
dtype: torch.dtype
21+
axis: int = 0
22+
per_channel: bool = False
23+
24+
def quantize_value(self, x: torch.Tensor | float) -> Tensor:
25+
"""Quantizes the input tensor or value to a quantized tensor. If the input is
26+
not a tensor, it is converted to a tensor first. If self.per_channel is True,
27+
the quantization is done per channel, otherwise it is done per tensor.
28+
"""
29+
if not isinstance(x, torch.Tensor):
30+
x = torch.Tensor([x])
31+
x = x.to(torch.float32)
32+
if self.per_channel:
33+
q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
34+
args = (
35+
x,
36+
torch.tensor(self.scale),
37+
torch.tensor(self.zp),
38+
self.axis,
39+
self.qmin,
40+
self.qmax,
41+
self.dtype,
42+
)
43+
else:
44+
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
45+
args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
46+
return q_op(*args)
47+
48+
def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
49+
"""Dequantizes the input tensor or value to a dequantized tensor If the input
50+
is not a tensor, it is converted to a tensor first. If self.per_channel is True,
51+
the dequantization is done per channel, otherwise it is done per tensor.
52+
"""
53+
if self.per_channel:
54+
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
55+
args = (
56+
qx,
57+
torch.tensor(self.scale),
58+
torch.tensor(self.zp),
59+
self.axis,
60+
self.qmin,
61+
self.qmax,
62+
self.dtype,
63+
)
64+
else:
65+
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
66+
args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
67+
return dq_op(*args)
68+
69+
@classmethod
70+
def from_operator(cls, op, args):
71+
if op in PER_TENSOR_QDQ_OPS:
72+
return cls(
73+
scale=cast(float, args[1]),
74+
zp=cast(int, args[2]),
75+
qmin=cast(int, args[3]),
76+
qmax=cast(int, args[4]),
77+
dtype=cast(torch.dtype, args[5]),
78+
axis=0,
79+
per_channel=False,
80+
)
81+
elif op in PER_CHANNEL_QDQ_OPS:
82+
return cls(
83+
scale=cast(list[float], args[1].tolist()),
84+
zp=cast(list[int], args[2].tolist()),
85+
axis=cast(int, args[3]),
86+
qmin=cast(int, args[4]),
87+
qmax=cast(int, args[5]),
88+
dtype=cast(torch.dtype, args[6]),
89+
per_channel=True,
90+
)
91+
else:
92+
# We're only handling per tensor and per channel quantization
93+
raise NotImplementedError(f"Unsupported quantization operation: {op}")
94+
95+
def get_scale_per_tensor(self) -> float:
96+
if not isinstance(self.scale, float):
97+
raise TypeError(
98+
f"Expected scale {self.scale} to be a float but found scale of "
99+
f"type {type(self.scale)}"
100+
)
101+
return self.scale
102+
103+
def get_zp_per_tensor(self) -> int:
104+
if not isinstance(self.zp, int):
105+
raise TypeError(
106+
f"Expected zero point {self.zp} to be an int but found zp of "
107+
f"type {type(self.zp)}"
108+
)
109+
return self.zp
110+
111+
def get_scale_per_channel(self) -> list[float]:
112+
if not isinstance(self.scale, list):
113+
raise TypeError(
114+
f"Expected scale {self.scale} to be a list but found scale of "
115+
f"type {type(self.scale)}"
116+
)
117+
return self.scale
118+
119+
def get_zp_per_channel(self) -> list[int]:
120+
if not isinstance(self.zp, list):
121+
raise TypeError(
122+
f"Expected zero point {self.zp} to be a list but found zp of "
123+
f"type {type(self.zp)}"
124+
)
125+
return self.zp

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ class EthosU55NotSupported(OperatorSupportBase):
149149
exir_ops.edge.aten.ne.Scalar,
150150
exir_ops.edge.aten.flip.default, # REVERSE
151151
exir_ops.edge.aten.grid_sampler_2d, # GATHER
152+
exir_ops.edge.aten.index.Tensor, # GATHER
153+
exir_ops.edge.aten.index_select.default, # GATHER
152154
exir_ops.edge.aten.scatter.src,
153155
exir_ops.edge.aten.scatter.value,
154156
exir_ops.edge.aten.select_scatter.default,

backends/arm/test/ops/test_index_select.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from executorch.backends.arm.test import common
1414
from executorch.backends.arm.test.tester.test_pipeline import (
15+
OpNotSupportedPipeline,
1516
TosaPipelineFP,
1617
TosaPipelineINT,
1718
VgfPipeline,
@@ -120,6 +121,20 @@ def test_index_select_tosa_INT_rand(test_data: input_params):
120121
pipeline.run()
121122

122123

124+
@pytest.mark.parametrize("test_data", list(test_data.values())[-1:])
def test_index_select_u55_INT_not_delegated(test_data: input_params):
    """Run OpNotSupportedPipeline on index_select with quantize=True, u55_subset=True.

    Parametrized over only the last entry of the module-level ``test_data``
    mapping. Each entry unpacks into the module under test and its inputs.
    """
    module, inputs = test_data
    # Presumably maps each op expected to remain non-delegated to its count;
    # confirm against OpNotSupportedPipeline.
    expected_ops = {module.exir_op: 1}

    OpNotSupportedPipeline[input_params](
        module,
        inputs,
        expected_ops,
        quantize=True,
        u55_subset=True,
    ).run()
136+
137+
123138
@pytest.mark.parametrize("test_data", list(test_data.values()))
124139
@common.SkipIfNoModelConverter
125140
def test_index_select_vgf_FP(test_data: input_params):

backends/arm/test/ops/test_index_tensor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import (
13+
OpNotSupportedPipeline,
1314
TosaPipelineFP,
1415
TosaPipelineINT,
1516
)
@@ -460,3 +461,18 @@ def test_index_tensor_tosa_INT_none(test_data: input_params):
460461
IndexTensorTestCommon.exir_op,
461462
).run()
462463
)
464+
465+
466+
@common.parametrize("test_data", IndexTensor.test_data)
@common.XfailIfNoCorstone300
def test_index_tensor_u55_INT_not_delegated(test_data: input_params):
    """Run OpNotSupportedPipeline on index.Tensor with quantize=True, u55_subset=True.

    The pipeline is built and run with gradient tracking disabled.
    """
    # Presumably maps each op expected to remain non-delegated to its count;
    # confirm against OpNotSupportedPipeline.
    expected_ops = {IndexTensorTestCommon.exir_op: 1}

    with torch.no_grad():
        pipeline = OpNotSupportedPipeline[input_params](
            IndexTensor(),
            test_data,
            expected_ops,
            quantize=True,
            u55_subset=True,
        )
        pipeline.run()

0 commit comments

Comments
 (0)