Skip to content

Arm backend: Remove instance checks for Tosa_1.00 #13219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from typing import Any, cast, Dict

import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from torch._export.utils import (
get_buffer,
Expand Down Expand Up @@ -81,11 +82,6 @@ def process_inputs(
"Is the original torch function supported?"
) from e

if isinstance(tosa_spec, Tosa_1_00):
import serializer.tosa_serializer as ts
else:
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")

input_shape = tosa_arg.shape
input_dim_order = tosa_arg.dim_order
tensor = ts.TosaSerializerTensor(
Expand Down
10 changes: 1 addition & 9 deletions backends/arm/tosa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import logging
from typing import cast, final, List

import executorch.backends.arm.tosa_specification as tosa_specification

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_backend import get_tosa_spec
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm._passes import (
Expand Down Expand Up @@ -85,13 +84,6 @@ def preprocess( # noqa: C901

# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
import serializer.tosa_serializer as ts # type: ignore
else:
raise RuntimeError(
f"Unknown TOSA version {tosa_spec}, no pip package installed to handle serialization to that version."
)

tosa_graph = ts.TosaSerializer(artifact_path)

assert (
Expand Down
12 changes: 3 additions & 9 deletions backends/arm/tosa_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

from typing import Any, Optional, Sequence

import serializer.tosa_serializer as ts # type: ignore

import torch
from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification

UNSUPPORTED_DTYPES = (
torch.float64,
Expand All @@ -32,10 +34,6 @@
def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
if data_type in UNSUPPORTED_DTYPES:
raise ValueError(f"Unsupported type: {data_type}")
if isinstance(tosa_spec, Tosa_1_00):
import serializer.tosa_serializer as ts # type: ignore
else:
raise RuntimeError(f"Unsupported tosa_spec: {tosa_spec}")

dtype_map = {
torch.float32: ts.DType.FP32,
Expand Down Expand Up @@ -134,10 +132,6 @@ def __repr__(self):
if self.name is not None:
attrs.append(f"name={self.name!r}")
if self.dtype is not None:
if isinstance(self.tosa_spec, Tosa_1_00):
import serializer.tosa_serializer as ts # type: ignore
else:
raise RuntimeError(f"Unsupported tosa_spec: {self.tosa_spec}")
attrs.append(f"dtype={ts.DTypeNames[self.dtype]}")
if self.shape is not None:
attrs.append(f"shape={self.shape!r}")
Expand Down
87 changes: 38 additions & 49 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

from typing import Any, Tuple

import executorch.backends.arm.tosa_specification as tosa_specification

import serializer.tosa_serializer as ts # type: ignore
import torch.fx
import torch.fx.node

Expand Down Expand Up @@ -247,25 +246,18 @@ def build_rescale_to_int32(
) -> Any:
input_A_rescaled_to_int32 = None

if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
import serializer.tosa_serializer as ts # type: ignore

input_A_rescaled_to_int32 = tosa_fb.addIntermediate(
input_arg.shape, ts.DType.INT32
)
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32)

build_rescale(
tosa_fb,
[rescale_scale],
input_arg,
input_A_rescaled_to_int32.name,
ts.DType.INT32,
[input_zp],
[0],
rounding_mode=RoundingMode.SINGLE_ROUND,
) # type: ignore[call-arg]
build_rescale(
tosa_fb,
[rescale_scale],
input_arg,
input_A_rescaled_to_int32.name,
ts.DType.INT32,
[input_zp],
[0],
rounding_mode=RoundingMode.SINGLE_ROUND,
) # type: ignore[call-arg]

return input_A_rescaled_to_int32

Expand All @@ -281,21 +273,19 @@ def build_rescale_from_int32(
per_channel: bool = False,
tosa_spec=None,
) -> None:
if isinstance(tosa_spec, tosa_specification.Tosa_1_00):
import serializer.tosa_serializer as ts # type: ignore

# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale(
tosa_fb,
[rescale_scale],
input_node,
output_name=output_name,
output_type=ts.DType.INT8,
input_zp=[0],
output_zp=[output_zp],
rounding_mode=RoundingMode.SINGLE_ROUND,
) # type: ignore[call-arg]
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale(
tosa_fb,
[rescale_scale],
input_node,
output_name=output_name,
output_type=ts.DType.INT8,
input_zp=[0],
output_zp=[output_zp],
rounding_mode=RoundingMode.SINGLE_ROUND,
) # type: ignore[call-arg]

return


Expand All @@ -318,18 +308,17 @@ def build_rescale_conv_output(
(inp * w) / out for inp, w, out in zip(input_scale, weight_scale, output_scale)
]

if isinstance(tosa_spec[0], tosa_specification.Tosa_1_00):
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale(
tosa_fb=tosa_fb,
scale=post_conv2d_scale,
input_node=op,
output_name=output_name,
output_type=output_type,
input_zp=[0],
output_zp=output_zp,
rounding_mode=RoundingMode.SINGLE_ROUND,
per_channel=isinstance(weight_scale, torch.Tensor),
) # type: ignore[call-arg]
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale(
tosa_fb=tosa_fb,
scale=post_conv2d_scale,
input_node=op,
output_name=output_name,
output_type=output_type,
input_zp=[0],
output_zp=output_zp,
rounding_mode=RoundingMode.SINGLE_ROUND,
per_channel=isinstance(weight_scale, torch.Tensor),
) # type: ignore[call-arg]
return
45 changes: 17 additions & 28 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg

from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.print_program import inspect_node

Expand Down Expand Up @@ -169,14 +169,6 @@ def broadcast_tensors(
for broadcast. However this function also performs the broadcast and
does not have a limit on only two input tensors.
"""

if isinstance(tosa_spec, Tosa_1_00):
import serializer.tosa_serializer as ts

reshape_helper = build_reshape_tosa_1_0
else:
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")

index_fake_tensors = [node.meta["val"] for node in nodes]
broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors)
if not broadcastable:
Expand All @@ -198,26 +190,25 @@ def broadcast_tensors(
tens_dtype,
)

reshape_helper(tosa_fb, node.name, new_shape, reshaped.name)
build_reshape_tosa_1_0(tosa_fb, node.name, new_shape, reshaped.name)

tiled = tosa_fb.addIntermediate(common_shape, tens_dtype)
multipliers = [
comm if curr == 1 else 1 for comm, curr in zip(common_shape, new_shape)
]
if isinstance(tosa_spec, Tosa_1_00):
multiple_shapes = tosa_fb.addConst(
(len(multipliers),),
ts.DType.SHAPE,
multipliers,
name=f"{node.name}_multiples",
)
multiple_shapes = tosa_fb.addConst(
(len(multipliers),),
ts.DType.SHAPE,
multipliers,
name=f"{node.name}_multiples",
)

tosa_fb.addOperator(
ts.TosaOp.Op().TILE,
[reshaped.name, multiple_shapes.name],
[tiled.name],
None,
)
tosa_fb.addOperator(
ts.TosaOp.Op().TILE,
[reshaped.name, multiple_shapes.name],
[tiled.name],
None,
)

broadcast_tensors.append(tiled)

Expand All @@ -227,19 +218,17 @@ def broadcast_tensors(
def build_reshape_tosa_1_0(
tosa_graph, input_name, new_shape, output_name, shape_name_override=""
):
import serializer.tosa_serializer as ts_ # type: ignore

shape = tosa_graph.addConst(
np.array(new_shape).shape,
ts_.DType.SHAPE,
ts.DType.SHAPE,
np.array(new_shape),
name=shape_name_override if shape_name_override else output_name + "_shape",
)

attr = ts_.TosaSerializerAttribute()
attr = ts.TosaSerializerAttribute()
attr.ReshapeAttribute()
tosa_graph.addOperator(
ts_.TosaOp.Op().RESHAPE,
ts.TosaOp.Op().RESHAPE,
[input_name, shape.name],
[output_name],
attr,
Expand Down
Loading