Skip to content

Commit 5daaed4

Browse files
committed
NXP backend: Add pre-processing pass to move view_copy nodes into their own QDQ clusters.
A PyTorch model can contain a `Linear` operator with 4D IO. After quantization, it gets its own QDQ cluster. But after lowering to edge, `view_copy` operators are added within the cluster, right before and after the `Linear` (now `addmm`/`mm`). This does not follow the QDQ schema and causes issues later down the pipeline. Therefore, pre-processing passes at the edge dialect level were implemented to move the `view_copy` nodes into their own QDQ clusters.
1 parent 5ff9157 commit 5daaed4

File tree

7 files changed

+363
-27
lines changed

7 files changed

+363
-27
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
9+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from torch.fx import Node
12+
from torch.fx.passes.infra.pass_base import PassResult
13+
14+
15+
def insert_qdq_pair_after_node(
    graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
):
    """Insert a `quantize_per_tensor` -> `dequantize_per_tensor` pair right after `anchor`.

    Every original user of `anchor` is re-routed to consume the output of the new dequantize node, so the
     resulting chain is `anchor -> quantize -> dequantize -> <original users of anchor>`.

    :param graph: Graph which contains `anchor`. It is modified in place.
    :param anchor: Node after which the new pair of nodes is inserted.
    :param q_params: Arguments placed after the input tensor in the `args` of both new nodes (the callers in this
                      file pass `args[1:]` of an existing quantize/dequantize node).
    """
    # Insert a Quantize node.
    with graph.inserting_after(anchor):
        quantize_op = graph.create_node(
            op="call_function",
            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # Will be added later.
        )
        # Use a shallow copy of the metadata. Assigning `anchor.meta` directly would make the nodes share a single
        #  mutable dictionary, so a later metadata update of one node would silently modify the others as well.
        quantize_op.meta = dict(anchor.meta)

    # Insert a Dequantize node.
    with graph.inserting_after(quantize_op):
        dequantize_op = graph.create_node(
            op="call_function",
            target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(quantize_op,) + q_params,
        )
        # Again an independent copy, for the same reason as above.
        dequantize_op.meta = dict(anchor.meta)
        anchor.replace_all_uses_with(dequantize_op)

    # Add this at the end, so the `anchor.replace_all_uses_with(dequantize_op)` does not replace the first use of the
    #  `quantize_op`.
    quantize_op.args = (anchor,) + q_params
40+
41+
42+
def _is_dequantize(node_: Node) -> bool:
    """Tell whether `node_` is a call of the edge dialect `dequantize_per_tensor` operator."""
    if node_.op != "call_function":
        return False

    dequantize_target = (
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    )
    return node_.target == dequantize_target
48+
49+
50+
def _is_quantize(node_: Node) -> bool:
    """Tell whether `node_` is a call of the edge dialect `quantize_per_tensor` operator."""
    if node_.op != "call_function":
        return False

    quantize_target = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    return node_.target == quantize_target
56+
57+
58+
class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
    """Separate an auxiliary operator which directly precedes the main node of a QDQ cluster, by inserting a new
     quantize/dequantize pair between the auxiliary node and the main node.

                                                             │
                                                       ┌─────▼──────┐
                │                                      │ dequantize │
          ┌─────▼──────┐                               └─────┬──────┘
          │ dequantize │                               ┌─────▼──────┐
          └─────┬──────┘                               │ <aux_node> │
          ┌─────▼──────┐                               └─────┬──────┘
          │ <aux_node> │                                ┌────▼─────┐          ┐
          └─────┬──────┘                                │ quantize │          │
     ┌──────────▼──────────┐      replaced with         └────┬─────┘          │
    ⋯┤ <main_cluster_node> ├⋯    ──────────────►             │                │ newly added nodes
     └──────────┬──────────┘                           ┌─────▼──────┐         │
                ▼                                      │ dequantize │         │
                ⋮                                      └─────┬──────┘         ┘
           ┌────▼─────┐                           ┌──────────▼──────────┐
           │ quantize │                          ⋯┤ <main_cluster_node> ├⋯
           └────┬─────┘                           └──────────┬──────────┘
                ▼                                            ▼
                                                             ⋮
                                                        ┌────▼─────┐
                                                        │ quantize │
                                                        └────┬─────┘
                                                             ▼
    """

    # Operators which are treated as the `<aux_node>`.
    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

    # List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
    allowed_main_cluster_nodes = [
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.mm.default,
    ]

    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the transformation to the first matching `<aux_node>` found in the graph.

        The returned `PassResult` carries `True` if a QDQ pair was inserted, `False` if no node matched.
        """
        for aux_node in graph_module.graph.nodes:
            is_auxiliary = (
                aux_node.op == "call_function"
                and aux_node.target in self.allowed_auxiliary_nodes
            )
            if not is_auxiliary:
                continue

            # The auxiliary node must be fed directly by a dequantize node.
            dequantize_node = aux_node.args[0]
            if not _is_dequantize(dequantize_node):
                # Not the intended use case.
                continue

            # The auxiliary node must have exactly 1 consumer.
            if len(aux_node.users) != 1:
                # Not the intended use case.
                continue

            (main_cluster_node,) = aux_node.users
            is_supported_main_node = (
                main_cluster_node.op == "call_function"
                and main_cluster_node.target in self.allowed_main_cluster_nodes
            )
            if not is_supported_main_node:
                # Unsupported `main_cluster_node`.
                continue

            # Make sure the nodes are part of the same QDQ cluster.
            cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
            required_members = (dequantize_node, aux_node, main_cluster_node)
            if not all(member in cluster for member in required_members):
                continue

            # ---- The nodes follow the pattern described in the header. ----

            # Reuse the quantization parameters of the existing dequantize node for the new QDQ pair.
            insert_qdq_pair_after_node(
                graph_module.graph, aux_node, dequantize_node.args[1:]
            )

            # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent
            #  class will call this pass again.
            return PassResult(graph_module, True)

        # Nothing was changed.
        return PassResult(graph_module, False)
138+
139+
140+
class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
    """Separate an auxiliary operator which directly follows the main node of a QDQ cluster, by inserting a new
     quantize/dequantize pair between the main node and the auxiliary node.

                                                             │
                                                       ┌─────▼──────┐
                │                                      │ dequantize │
          ┌─────▼──────┐                               └─────┬──────┘
          │ dequantize │                                     ⋮
          └─────┬──────┘                          ┌──────────▼──────────┐
                ▼                                ⋯┤ <main_cluster_node> ├⋯
                ⋮                                 └──────────┬──────────┘
     ┌──────────▼──────────┐      replaced with         ┌────▼─────┐          ┐
    ⋯┤ <main_cluster_node> ├⋯    ──────────────►        │ quantize │          │
     └──────────┬──────────┘                            └────┬─────┘          │
          ┌─────▼──────┐                                     │                │ newly added nodes
          │ <aux_node> │                               ┌─────▼──────┐         │
          └─────┬──────┘                               │ dequantize │         │
           ┌────▼─────┐                                └─────┬──────┘         ┘
           │ quantize │                                ┌─────▼──────┐
           └────┬─────┘                                │ <aux_node> │
                ▼                                      └─────┬──────┘
                                                        ┌────▼─────┐
                                                        │ quantize │
                                                        └────┬─────┘
                                                             ▼
    """

    # Operators which are treated as the `<aux_node>`.
    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

    # List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
    allowed_main_cluster_nodes = [
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.mm.default,
    ]

    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the transformation to the first matching `<aux_node>` found in the graph.

        The returned `PassResult` carries `True` if a QDQ pair was inserted, `False` if no node matched.
        """
        for aux_node in graph_module.graph.nodes:
            is_auxiliary = (
                aux_node.op == "call_function"
                and aux_node.target in self.allowed_auxiliary_nodes
            )
            if not is_auxiliary:
                continue

            # The auxiliary node must directly consume the output of a supported main node.
            main_cluster_node = aux_node.args[0]
            is_supported_main_node = (
                main_cluster_node.op == "call_function"
                and main_cluster_node.target in self.allowed_main_cluster_nodes
            )
            if not is_supported_main_node:
                # Unsupported `main_cluster_node`.
                continue

            # The auxiliary node must have exactly 1 consumer.
            if len(aux_node.users) != 1:
                # Not the intended use case.
                continue

            (quantize_node,) = aux_node.users
            if not _is_quantize(quantize_node):
                # Not the intended use case.
                continue

            # Make sure the nodes are part of the same QDQ cluster.
            cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
            required_members = (quantize_node, aux_node, main_cluster_node)
            if not all(member in cluster for member in required_members):
                continue

            # ---- The nodes follow the pattern described in the header. ----

            # Reuse the quantization parameters of the existing quantize node for the new QDQ pair.
            insert_qdq_pair_after_node(
                graph_module.graph, main_cluster_node, quantize_node.args[1:]
            )

            # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent
            #  class will call this pass again.
            return PassResult(graph_module, True)

        # Nothing was changed.
        return PassResult(graph_module, False)

backends/nxp/edge_passes/neutron_edge_pass.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@
77
from abc import abstractmethod
88

99
import torch
10-
from torch.fx.passes.infra.pass_base import PassResult
1110

1211
from executorch.exir.pass_base import ExportPass
12+
from torch.fx.passes.infra.pass_base import PassResult
1313

1414

1515
class NeutronEdgePass(ExportPass):
16-
""" Abstract parent class for pre-processing passes on the edge dialect level. """
16+
"""Abstract parent class for pre-processing passes on the edge dialect level."""
1717

1818
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
19-
""" Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on
20-
iterating through its nodes, and must return. This method allows the pass to go through the whole model.
19+
"""Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on
20+
iterating through its nodes, and must return. This method allows the pass to go through the whole model.
2121
"""
2222

2323
# Every pass will return once it makes a change to the graph, to avoid traversing and modifying a graph at the
@@ -36,19 +36,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
3636
return PassResult(graph_module, modified)
3737

3838
# Iteration limit was reached.
39-
logging.warning(f'The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit.')
39+
logging.warning(
40+
f"The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit."
41+
)
4042
graph_module = self.recompile_module(graph_module)
4143
return PassResult(graph_module, modified)
4244

4345
@abstractmethod
4446
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
45-
""" Child classes should implement their graph modification here. """
47+
"""Child classes should implement their graph modification here."""
4648
pass
4749

4850
def recompile_module(
4951
self, graph_module: torch.fx.GraphModule
5052
) -> torch.fx.GraphModule:
51-
""" Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct.
52-
"""
53+
"""Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct."""
5354
graph_module.recompile()
5455
return super().call(graph_module).graph_module

backends/nxp/edge_passes/neutron_edge_pass_manager.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,38 @@
55

66
import copy
77

8-
from torch import nn
9-
from torch.export import ExportedProgram
10-
from torch.fx.passes.infra.pass_base import PassResult
11-
from torch.fx.passes.infra.pass_manager import PassManager
12-
138
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
149
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
1510
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
1611
)
1712
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
1813
from executorch.exir import EdgeProgramManager
19-
from executorch.exir.program._program import _get_updated_graph_signature, _get_updated_range_constraints
14+
from executorch.exir.program._program import (
15+
_get_updated_graph_signature,
16+
_get_updated_range_constraints,
17+
)
18+
19+
from torch import nn
20+
from torch.export import ExportedProgram
21+
from torch.fx.passes.infra.pass_base import PassResult
22+
from torch.fx.passes.infra.pass_manager import PassManager
2023

2124

2225
class NeutronEdgePassManager(PassManager):
2326

2427
def __init__(self, passes: list[NeutronEdgePass] = None):
2528
passes: list[NeutronEdgePass] = passes or [
29+
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
30+
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
2631
]
2732

2833
super().__init__(
2934
passes,
30-
steps=10 # Empirical value. At most 10 cycles of passes will be run.
35+
steps=10, # Empirical value. At most 10 cycles of passes will be run.
3136
)
3237

3338
def _transform_graph_module(self, module: nn.Module) -> PassResult:
34-
""" Apply the passes to a single graph module. """
39+
"""Apply the passes to a single graph module."""
3540
pass_result: PassResult = super().__call__(module)
3641

3742
graph_module = pass_result.graph_module
@@ -41,7 +46,7 @@ def _transform_graph_module(self, module: nn.Module) -> PassResult:
4146
return pass_result
4247

4348
def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
44-
""" Apply the passes to all graph modules in the edge program. """
49+
"""Apply the passes to all graph modules in the edge program."""
4550
new_programs: dict[str, ExportedProgram] = {}
4651

4752
for name, program in epm._edge_programs.items():
@@ -56,7 +61,9 @@ def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
5661
program.graph_signature, pass_result.graph_module
5762
),
5863
state_dict=program.state_dict,
59-
range_constraints=_get_updated_range_constraints(pass_result.graph_module),
64+
range_constraints=_get_updated_range_constraints(
65+
pass_result.graph_module
66+
),
6067
module_call_graph=copy.deepcopy(program._module_call_graph),
6168
example_inputs=program.example_inputs,
6269
constants=program.constants,
@@ -77,4 +84,6 @@ def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
7784

7885
else:
7986
# Return a new EdgeProgramManager with the updated programs.
80-
return EdgeProgramManager(new_programs, copy.deepcopy(epm._config_methods), epm.compile_config)
87+
return EdgeProgramManager(
88+
new_programs, copy.deepcopy(epm._config_methods), epm.compile_config
89+
)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7-
from torch import nn
8-
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
97

108
from executorch import exir
11-
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import NeutronEdgePassManager
9+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
10+
NeutronEdgePassManager,
11+
)
1212
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1313
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
1414
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -19,6 +19,8 @@
1919
ExecutorchProgramManager,
2020
)
2121
from executorch.extension.export_util.utils import export_to_edge
22+
from torch import nn
23+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2224

2325

2426
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):

backends/nxp/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,24 @@ def forward(self, x):
125125
return x
126126

127127

128+
class ConvFCFCSoftmaxModuleWithoutReshape(torch.nn.Module):
    """Test model: a bias-free 2D convolution, two `Linear` layers applied directly to the convolution output
    (there is no reshape/flatten in between), and a softmax over dimension 1.
    """

    def __init__(self):
        super().__init__()

        # The registration order (conv, fc1, fc2, softmax) is kept, as it defines the order of the sub-modules.
        self.conv = torch.nn.Conv2d(
            in_channels=4, out_channels=5, kernel_size=2, bias=False
        )
        self.fc1 = torch.nn.Linear(in_features=32, out_features=16)
        self.fc2 = torch.nn.Linear(in_features=16, out_features=8)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        # `fc1` operates on the last dimension of the convolution output, which therefore has to be of size 32.
        return self.softmax(self.fc2(self.fc1(self.conv(x))))
144+
145+
128146
class ConstantPadNDModule(torch.nn.Module):
129147
def __init__(self, paddings: Collection[int], constant: float | int | None = None):
130148
super().__init__()

0 commit comments

Comments
 (0)