Commit b36d6b6

Enable strongly typed ops for deployment
Differential Revision: D79867630
Pull Request resolved: #13230
Parent: 5aa127e

11 files changed: +471 −4 lines

backends/cadence/aot/TARGETS

Lines changed: 32 additions & 0 deletions
```diff
@@ -101,6 +101,7 @@ python_library(
         ":reorder_ops",
         ":replace_ops",
         ":simplify_ops",
+        ":type_dispatch",
         ":utils",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
@@ -322,6 +323,37 @@ python_library(
     ],
 )
 
+python_library(
+    name = "type_dispatch",
+    srcs = [
+        "type_dispatch.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+    ],
+)
+
+python_unittest(
+    name = "test_type_dispatch_passes",
+    srcs = [
+        "tests/test_type_dispatch_passes.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        ":ops_registrations",
+        ":type_dispatch",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 python_library(
     name = "typing_stubs",
     srcs = [
```

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
```diff
@@ -254,6 +254,16 @@
   - arg_meta: null
     kernel_name: impl::reference::quantized_fully_connected_per_tensor_out
 
+- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::reference::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::reference::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
+
 - func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
```
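A note on the naming: the suffixes on the new typed variants encode the operand dtypes in the pattern `<src>x<weight>_<out>`, where `asym8s` is asymmetric signed 8-bit (`int8`) and `asym8u` is asymmetric unsigned 8-bit (`uint8`). A minimal sketch of that mapping (`_TYPE_SUFFIX` and `typed_variant_name` are hypothetical helpers for illustration, not part of this commit):

```python
import torch

# Hypothetical helper, not in this commit: maps (src, weight) dtypes to the
# suffix used by the strongly typed kernels, in the pattern <src>x<weight>_<out>.
_TYPE_SUFFIX = {
    (torch.int8, torch.int8): "asym8sxasym8s_asym8s",    # signed 8-bit in/out
    (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",  # unsigned 8-bit in/out
}


def typed_variant_name(base: str, src: torch.Tensor, weight: torch.Tensor) -> str:
    suffix = _TYPE_SUFFIX.get((src.dtype, weight.dtype))
    if suffix is None:
        raise RuntimeError(f"Unsupported input types: {src.dtype} and {weight.dtype}")
    return f"{base}_{suffix}"


# typed_variant_name("quantized_fully_connected", int8_src, int8_weight)
# -> "quantized_fully_connected_asym8sxasym8s_asym8s"
```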

backends/cadence/aot/functions_hifi.yaml

Lines changed: 14 additions & 4 deletions
```diff
@@ -329,17 +329,27 @@
   - arg_meta: null
     kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out
 
-- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
-    kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
+    kernel_name: cadence::impl::HiFi::quantized_matmul_out
 
-- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
-    kernel_name: cadence::impl::HiFi::quantized_matmul_out
+    kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
 
 - func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
     kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out
+
+- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+  - arg_meta: null
+    kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+  - arg_meta: null
+    kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
```

backends/cadence/aot/ops_registrations.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -162,6 +162,14 @@
     "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
+lib.define(
+    "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
 lib.define(
     "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
@@ -240,6 +248,14 @@
     "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -754,6 +770,50 @@ def quantized_fully_connected_per_tensor_meta(
     return src.new_empty(out_size, dtype=src.dtype)
 
 
+@register_fake("cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
+def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
+def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
 @register_fake("cadence::convolution")
 def convolution_meta(
     input: torch.Tensor,
```
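The fake (meta) kernels registered above only compute output shapes for tracing: a `src` of shape `[*leading_dims, in_dim]` and a `weight` of shape `[out_dim, in_dim]` yield an output of shape `[*leading_dims, out_dim]`. A standalone sketch of the same shape rule, using plain tensors:

```python
import torch

# The shape rule the registered fake kernels implement:
# src: [*leading_dims, in_dim], weight: [out_dim, in_dim] -> out: [*leading_dims, out_dim]
src = torch.zeros(2, 5, 3, dtype=torch.int8)  # leading_dims = (2, 5), in_dim = 3
weight = torch.zeros(4, 3, dtype=torch.int8)  # out_dim = 4, in_dim = 3

out_size = list(src.size())
assert len(weight.size()) == 2  # weight must be 2-D
out_size[-1] = weight.size(0)   # replace in_dim with out_dim
out = src.new_empty(out_size, dtype=src.dtype)
assert out.shape == (2, 5, 4)
```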

backends/cadence/aot/passes.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
     ReplaceMulTensorWithMulAndFullOpsPass,
 )
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
 from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
@@ -90,6 +91,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
         FuseFullThenReshapePass,
         FuseTransposeOrPermuteOpPairsPass,
         RemoveNopSliceOrViewOpPass,
+        CompileTimeTypeDispatchPass,
     ]
     return pytree.tree_flatten(passes)[0]
```
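`CompileTimeTypeDispatchPass` is appended at the end of the default pass list, so, assuming passes run in list order, type specialization happens only after the fusion and cleanup passes have settled the graph. A quick check of that ordering (a sketch; the import paths are taken from this diff):

```python
from executorch.backends.cadence.aot.passes import get_passes_in_default_order
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass

passes = get_passes_in_default_order()
assert CompileTimeTypeDispatchPass in passes
# Added as the final list entry in this commit, so it sees the optimized graph.
assert passes[-1] is CompileTimeTypeDispatchPass
```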

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 87 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest
from typing import cast

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.infra.pass_base import PassResult


class TestTypeDispatchPasses(unittest.TestCase):
    def test_int8_dispatch(self) -> None:
        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
        w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
        b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
            args=(x, w, b, 0, 0, 1, 0, 0, None),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
            0,
        )
        # Should be replaced with int8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
            ),
            1,
        )

    def test_uint8_dispatch(self) -> None:
        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
        w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
        b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
            args=(x, w, b, 0, 0, 1, 0, 0, None),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
            0,
        )
        # Should be replaced with uint8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
            ),
            1,
        )

    def test_mixed_types_error(self) -> None:
        """Test mixed int8/uint8 inputs should raise RuntimeError"""
        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
        w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
        b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
            args=(x, w, b, 0, 0, 1, 0, 0, None),
        )
        p = CompileTimeTypeDispatchPass()
        # Mixed types should raise RuntimeError
        with self.assertRaises(RuntimeError) as context:
            cast(PassResult, p(gm)).graph_module
        self.assertIn("Unsupported input types", str(context.exception))
```

backends/cadence/aot/type_dispatch.py

Lines changed: 62 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch._ops import OpOverload
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=4))
class CompileTimeTypeDispatchPass(ExportPass):
    """
    Replaces generic ops with ops that have explicit types.
    """

    def call_operator(
        self,
        op: OpOverload,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
        }:
            return super().call_operator(op, args, kwargs, meta)

        if (
            # pyre-ignore[16]: None has no attribute `to_tensor`.
            args[0].to_tensor().dtype == torch.int8
            and args[1].to_tensor().dtype == torch.int8
        ):
            return super().call_operator(
                exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
                args,
                kwargs,
                meta,
            )
        elif (
            args[0].to_tensor().dtype == torch.uint8
            and args[1].to_tensor().dtype == torch.uint8
        ):
            return super().call_operator(
                exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
                args,
                kwargs,
                meta,
            )
        else:
            raise RuntimeError(
                f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}"
            )
```
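The if/elif dispatch above covers exactly the two variants added in this commit. If more typed variants land later, a dtype-keyed table would keep the dispatch flat; a hypothetical sketch (the `_TYPED_VARIANTS` table and `lookup_typed_variant` helper are assumptions for illustration, not part of this commit):

```python
import executorch.backends.cadence.aot.ops_registrations  # noqa: registers cadence ops
import torch
from executorch.exir.dialects._ops import ops as exir_ops

# Hypothetical refactor sketch: only the two entries below exist as of this commit.
_TYPED_VARIANTS = {
    (torch.int8, torch.int8):
        exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
    (torch.uint8, torch.uint8):
        exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
}


def lookup_typed_variant(src_dtype: torch.dtype, weight_dtype: torch.dtype):
    """Return the strongly typed op for a (src, weight) dtype pair, or raise."""
    typed_op = _TYPED_VARIANTS.get((src_dtype, weight_dtype))
    if typed_op is None:
        raise RuntimeError(f"Unsupported input types: {src_dtype} and {weight_dtype}")
    return typed_op
```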
