Skip to content

Commit a46a9c8

Browse files
Arm backend: Remove instance checks for Tosa_1.00
With tosa 0.80 removed there's no need to check which tosa version's being used. Change-Id: I8c0d5c7283a756a0a0374dfa73f812a75b9af177 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 564b6e0 commit a46a9c8

File tree

5 files changed

+61
-101
lines changed

5 files changed

+61
-101
lines changed

backends/arm/process_node.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from typing import Any, cast, Dict
99

1010
import numpy as np
11+
import serializer.tosa_serializer as ts
1112
import torch
1213
import torch.fx
1314
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1415
from executorch.backends.arm.tosa_mapping import TosaArg
15-
from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
1617
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
1718
from torch._export.utils import (
1819
get_buffer,
@@ -81,11 +82,6 @@ def process_inputs(
8182
"Is the original torch function supported?"
8283
) from e
8384

84-
if isinstance(tosa_spec, Tosa_1_00):
85-
import serializer.tosa_serializer as ts
86-
else:
87-
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")
88-
8985
input_shape = tosa_arg.shape
9086
input_dim_order = tosa_arg.dim_order
9187
tensor = ts.TosaSerializerTensor(

backends/arm/tosa_backend.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import logging
1414
from typing import cast, final, List
1515

16-
import executorch.backends.arm.tosa_specification as tosa_specification
17-
16+
import serializer.tosa_serializer as ts # type: ignore
1817
from executorch.backends.arm.arm_backend import get_tosa_spec
1918
from executorch.backends.arm.operators.node_visitor import get_node_visitors
2019
from executorch.backends.arm._passes import (
@@ -85,13 +84,6 @@ def preprocess( # noqa: C901
8584

8685
# Converted output for this subgraph, serializer needs path early as it emits
8786
# const data directly. Path created and data written only in debug builds.
88-
if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
89-
import serializer.tosa_serializer as ts # type: ignore
90-
else:
91-
raise RuntimeError(
92-
f"Unknown TOSA version {tosa_spec}, no pip package installed to handle serialization to that version."
93-
)
94-
9587
tosa_graph = ts.TosaSerializer(artifact_path)
9688

9789
assert (

backends/arm/tosa_mapping.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313

1414
from typing import Any, Optional, Sequence
1515

16+
import serializer.tosa_serializer as ts # type: ignore
17+
1618
import torch
17-
from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
19+
from executorch.backends.arm.tosa_specification import TosaSpecification
1820

1921
UNSUPPORTED_DTYPES = (
2022
torch.float64,
@@ -32,10 +34,6 @@
3234
def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
3335
if data_type in UNSUPPORTED_DTYPES:
3436
raise ValueError(f"Unsupported type: {data_type}")
35-
if isinstance(tosa_spec, Tosa_1_00):
36-
import serializer.tosa_serializer as ts # type: ignore
37-
else:
38-
raise RuntimeError(f"Unsupported tosa_spec: {tosa_spec}")
3937

4038
dtype_map = {
4139
torch.float32: ts.DType.FP32,
@@ -134,10 +132,6 @@ def __repr__(self):
134132
if self.name is not None:
135133
attrs.append(f"name={self.name!r}")
136134
if self.dtype is not None:
137-
if isinstance(self.tosa_spec, Tosa_1_00):
138-
import serializer.tosa_serializer as ts # type: ignore
139-
else:
140-
raise RuntimeError(f"Unsupported tosa_spec: {self.tosa_spec}")
141135
attrs.append(f"dtype={ts.DTypeNames[self.dtype]}")
142136
if self.shape is not None:
143137
attrs.append(f"shape={self.shape!r}")

backends/arm/tosa_quant_utils.py

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
from typing import Any, Tuple
1313

14-
import executorch.backends.arm.tosa_specification as tosa_specification
15-
14+
import serializer.tosa_serializer as ts # type: ignore
1615
import torch.fx
1716
import torch.fx.node
1817

@@ -247,25 +246,18 @@ def build_rescale_to_int32(
247246
) -> Any:
248247
input_A_rescaled_to_int32 = None
249248

250-
if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
251-
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
252-
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
253-
import serializer.tosa_serializer as ts # type: ignore
254-
255-
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(
256-
input_arg.shape, ts.DType.INT32
257-
)
249+
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32)
258250

259-
build_rescale(
260-
tosa_fb,
261-
[rescale_scale],
262-
input_arg,
263-
input_A_rescaled_to_int32.name,
264-
ts.DType.INT32,
265-
[input_zp],
266-
[0],
267-
rounding_mode=RoundingMode.SINGLE_ROUND,
268-
) # type: ignore[call-arg]
251+
build_rescale(
252+
tosa_fb,
253+
[rescale_scale],
254+
input_arg,
255+
input_A_rescaled_to_int32.name,
256+
ts.DType.INT32,
257+
[input_zp],
258+
[0],
259+
rounding_mode=RoundingMode.SINGLE_ROUND,
260+
) # type: ignore[call-arg]
269261

270262
return input_A_rescaled_to_int32
271263

@@ -281,21 +273,19 @@ def build_rescale_from_int32(
281273
per_channel: bool = False,
282274
tosa_spec=None,
283275
) -> None:
284-
if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
285-
import serializer.tosa_serializer as ts # type: ignore
286-
287-
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
288-
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
289-
build_rescale(
290-
tosa_fb,
291-
[rescale_scale],
292-
input_node,
293-
output_name=output_name,
294-
output_type=ts.DType.INT8,
295-
input_zp=[0],
296-
output_zp=[output_zp],
297-
rounding_mode=RoundingMode.SINGLE_ROUND,
298-
) # type: ignore[call-arg]
276+
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
277+
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
278+
build_rescale(
279+
tosa_fb,
280+
[rescale_scale],
281+
input_node,
282+
output_name=output_name,
283+
output_type=ts.DType.INT8,
284+
input_zp=[0],
285+
output_zp=[output_zp],
286+
rounding_mode=RoundingMode.SINGLE_ROUND,
287+
) # type: ignore[call-arg]
288+
299289
return
300290

301291

@@ -318,18 +308,17 @@ def build_rescale_conv_output(
318308
(inp * w) / out for inp, w, out in zip(input_scale, weight_scale, output_scale)
319309
]
320310

321-
if isinstance(tosa_spec[0], tosa_specification.Tosa_1_00):
322-
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
323-
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
324-
build_rescale(
325-
tosa_fb=tosa_fb,
326-
scale=post_conv2d_scale,
327-
input_node=op,
328-
output_name=output_name,
329-
output_type=output_type,
330-
input_zp=[0],
331-
output_zp=output_zp,
332-
rounding_mode=RoundingMode.SINGLE_ROUND,
333-
per_channel=isinstance(weight_scale, torch.Tensor),
334-
) # type: ignore[call-arg]
311+
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
312+
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
313+
build_rescale(
314+
tosa_fb=tosa_fb,
315+
scale=post_conv2d_scale,
316+
input_node=op,
317+
output_name=output_name,
318+
output_type=output_type,
319+
input_zp=[0],
320+
output_zp=output_zp,
321+
rounding_mode=RoundingMode.SINGLE_ROUND,
322+
per_channel=isinstance(weight_scale, torch.Tensor),
323+
) # type: ignore[call-arg]
335324
return

backends/arm/tosa_utils.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg
2020

21-
from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
21+
from executorch.backends.arm.tosa_specification import TosaSpecification
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from executorch.exir.print_program import inspect_node
2424

@@ -169,14 +169,6 @@ def broadcast_tensors(
169169
for broadcast. However this function also performs the broadcast and
170170
does not have a limit on only two input tensors.
171171
"""
172-
173-
if isinstance(tosa_spec, Tosa_1_00):
174-
import serializer.tosa_serializer as ts
175-
176-
reshape_helper = build_reshape_tosa_1_0
177-
else:
178-
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")
179-
180172
index_fake_tensors = [node.meta["val"] for node in nodes]
181173
broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors)
182174
if not broadcastable:
@@ -198,26 +190,25 @@ def broadcast_tensors(
198190
tens_dtype,
199191
)
200192

201-
reshape_helper(tosa_fb, node.name, new_shape, reshaped.name)
193+
build_reshape_tosa_1_0(tosa_fb, node.name, new_shape, reshaped.name)
202194

203195
tiled = tosa_fb.addIntermediate(common_shape, tens_dtype)
204196
multipliers = [
205197
comm if curr == 1 else 1 for comm, curr in zip(common_shape, new_shape)
206198
]
207-
if isinstance(tosa_spec, Tosa_1_00):
208-
multiple_shapes = tosa_fb.addConst(
209-
(len(multipliers),),
210-
ts.DType.SHAPE,
211-
multipliers,
212-
name=f"{node.name}_multiples",
213-
)
199+
multiple_shapes = tosa_fb.addConst(
200+
(len(multipliers),),
201+
ts.DType.SHAPE,
202+
multipliers,
203+
name=f"{node.name}_multiples",
204+
)
214205

215-
tosa_fb.addOperator(
216-
ts.TosaOp.Op().TILE,
217-
[reshaped.name, multiple_shapes.name],
218-
[tiled.name],
219-
None,
220-
)
206+
tosa_fb.addOperator(
207+
ts.TosaOp.Op().TILE,
208+
[reshaped.name, multiple_shapes.name],
209+
[tiled.name],
210+
None,
211+
)
221212

222213
broadcast_tensors.append(tiled)
223214

@@ -227,19 +218,17 @@ def broadcast_tensors(
227218
def build_reshape_tosa_1_0(
228219
tosa_graph, input_name, new_shape, output_name, shape_name_override=""
229220
):
230-
import serializer.tosa_serializer as ts_ # type: ignore
231-
232221
shape = tosa_graph.addConst(
233222
np.array(new_shape).shape,
234-
ts_.DType.SHAPE,
223+
ts.DType.SHAPE,
235224
np.array(new_shape),
236225
name=shape_name_override if shape_name_override else output_name + "_shape",
237226
)
238227

239-
attr = ts_.TosaSerializerAttribute()
228+
attr = ts.TosaSerializerAttribute()
240229
attr.ReshapeAttribute()
241230
tosa_graph.addOperator(
242-
ts_.TosaOp.Op().RESHAPE,
231+
ts.TosaOp.Op().RESHAPE,
243232
[input_name, shape.name],
244233
[output_name],
245234
attr,

0 commit comments

Comments
 (0)