diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py
new file mode 100644
index 00000000000..7eba60cf2ec
--- /dev/null
+++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py
@@ -0,0 +1,219 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
+from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx import Node
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+def insert_qdq_pair_after_node(
+    graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
+):
+    # Insert a Quantize node.
+    with graph.inserting_after(anchor):
+        quantize_op = graph.create_node(
+            op="call_function",
+            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(),  # Will be added later.
+        )
+        quantize_op.meta = anchor.meta
+
+    # Insert a Dequantize node.
+    with graph.inserting_after(quantize_op):
+        dequantize_op = graph.create_node(
+            op="call_function",
+            target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantize_op,) + q_params,
+        )
+        dequantize_op.meta = quantize_op.meta
+        anchor.replace_all_uses_with(dequantize_op)
+
+    # Assign the args at the end, so that `anchor.replace_all_uses_with(dequantize_op)` does not also replace the
+    # use of `anchor` by the newly created `quantize_op`.
+    quantize_op.args = (anchor,) + q_params
+
+
+def _is_dequantize(node_: Node) -> bool:
+    return (
+        node_.op == "call_function"
+        and node_.target
+        == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+    )
+
+
+def _is_quantize(node_: Node) -> bool:
+    return (
+        node_.op == "call_function"
+        and node_.target
+        == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+    )
+
+
+class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
+    """Move a leading auxiliary operator (e.g. `view_copy`) into its own QDQ cluster.
+
+    The pattern
+
+        dequantize -> <auxiliary node> -> <main cluster node> -> ... -> quantize
+
+    is replaced with
+
+        dequantize -> <auxiliary node> -> quantize -> dequantize -> <main cluster node> -> ... -> quantize
+
+    where the newly added `quantize -> dequantize` pair separates the auxiliary node into its own QDQ cluster.
+    """
+
+    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
+
+    # List of approved nodes to which the auxiliary node can be connected in order for the pass to make the
+    # modification.
+    allowed_main_cluster_nodes = [
+        exir_ops.edge.aten.addmm.default,
+        exir_ops.edge.aten.mm.default,
+    ]
+
+    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for aux_node in graph_module.graph.nodes:
+            if (
+                aux_node.op != "call_function"
+                or aux_node.target not in self.allowed_auxiliary_nodes
+            ):
+                continue
+
+            dequantize_node = aux_node.args[0]
+            if not _is_dequantize(dequantize_node):
+                # Not the intended use case.
+                continue
+
+            users = list(aux_node.users.keys())
+            if len(users) != 1:
+                # Not the intended use case.
+                continue
+
+            main_cluster_node = users[0]
+            if (
+                main_cluster_node.op != "call_function"
+                or main_cluster_node.target not in self.allowed_main_cluster_nodes
+            ):
+                # Unsupported `main_cluster_node`.
+                continue
+
+            # Make sure the nodes are part of the same QDQ cluster.
+            cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
+            if any(
+                node_ not in cluster
+                for node_ in [dequantize_node, aux_node, main_cluster_node]
+            ):
+                continue
+
+            # ---- The nodes follow the pattern described in the header. ----
+
+            q_params = dequantize_node.args[1:]
+            insert_qdq_pair_after_node(graph_module.graph, aux_node, q_params)
+
+            # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph, and the
+            # parent class will call this pass again.
+            return PassResult(graph_module, True)
+
+        # Nothing was changed.
+        return PassResult(graph_module, False)
+
+
+class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
+    """Move a trailing auxiliary operator (e.g. `view_copy`) into its own QDQ cluster.
+
+    The pattern
+
+        dequantize -> ... -> <main cluster node> -> <auxiliary node> -> quantize
+
+    is replaced with
+
+        dequantize -> ... -> <main cluster node> -> quantize -> dequantize -> <auxiliary node> -> quantize
+
+    where the newly added `quantize -> dequantize` pair separates the auxiliary node into its own QDQ cluster.
+    """
+
+    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
+
+    # List of approved nodes to which the auxiliary node can be connected in order for the pass to make the
+    # modification.
+    allowed_main_cluster_nodes = [
+        exir_ops.edge.aten.addmm.default,
+        exir_ops.edge.aten.mm.default,
+    ]
+
+    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
+
+        for aux_node in graph_module.graph.nodes:
+            if (
+                aux_node.op != "call_function"
+                or aux_node.target not in self.allowed_auxiliary_nodes
+            ):
+                continue
+
+            main_cluster_node = aux_node.args[0]
+            if (
+                main_cluster_node.op != "call_function"
+                or main_cluster_node.target not in self.allowed_main_cluster_nodes
+            ):
+                # Unsupported `main_cluster_node`.
+                continue
+
+            users = list(aux_node.users.keys())
+            if len(users) != 1:
+                # Not the intended use case.
+                continue
+
+            quantize_node = users[0]
+            if not _is_quantize(quantize_node):
+                # Not the intended use case.
+                continue
+
+            # Make sure the nodes are part of the same QDQ cluster.
+            cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
+            if any(
+                node_ not in cluster
+                for node_ in [quantize_node, aux_node, main_cluster_node]
+            ):
+                continue
+
+            # ---- The nodes follow the pattern described in the header. ----
+
+            q_params = quantize_node.args[1:]
+            insert_qdq_pair_after_node(graph_module.graph, main_cluster_node, q_params)
+
+            # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph, and the
+            # parent class will call this pass again.
+            return PassResult(graph_module, True)
+
+        # Nothing was changed.
+        return PassResult(graph_module, False)
diff --git a/backends/nxp/edge_passes/neutron_edge_pass.py b/backends/nxp/edge_passes/neutron_edge_pass.py
new file mode 100644
index 00000000000..8f77ce022fc
--- /dev/null
+++ b/backends/nxp/edge_passes/neutron_edge_pass.py
@@ -0,0 +1,55 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from abc import abstractmethod
+
+import torch
+
+from executorch.exir.pass_base import ExportPass
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+class NeutronEdgePass(ExportPass):
+    """Abstract parent class for pre-processing passes on the edge dialect level."""
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on
+        iterating through its nodes, and must return. This method allows the pass to go through the whole model.
+        """
+
+        # Every pass will return once it makes a change to the graph, to avoid traversing and modifying a graph at the
+        # same time. Therefore, it must be called multiple times (at most `iteration_limit` times).
+        iteration_limit = len(graph_module.graph.nodes)
+        modified = False
+        for _ in range(iteration_limit):
+            res = self.run(graph_module)
+            if res.modified:
+                modified = True
+                graph_module = res.graph_module
+
+            else:
+                # No more changes have been made.
+                graph_module = self.recompile_module(graph_module)
+                return PassResult(graph_module, modified)
+
+        # Iteration limit was reached.
+        logging.warning(
+            f"The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit."
+        )
+        graph_module = self.recompile_module(graph_module)
+        return PassResult(graph_module, modified)
+
+    @abstractmethod
+    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """Child classes should implement their graph modification here."""
+        pass
+
+    def recompile_module(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct."""
+        graph_module.recompile()
+        return super().call(graph_module).graph_module
diff --git a/backends/nxp/edge_passes/neutron_edge_pass_manager.py b/backends/nxp/edge_passes/neutron_edge_pass_manager.py
new file mode 100644
index 00000000000..ec46070ac31
--- /dev/null
+++ b/backends/nxp/edge_passes/neutron_edge_pass_manager.py
@@ -0,0 +1,89 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
+    MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
+    MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
+)
+from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
+from executorch.exir import EdgeProgramManager
+from executorch.exir.program._program import (
+    _get_updated_graph_signature,
+    _get_updated_range_constraints,
+)
+
+from torch import nn
+from torch.export import ExportedProgram
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import PassManager
+
+
+class NeutronEdgePassManager(PassManager):
+
+    def __init__(self, passes: list[NeutronEdgePass] = None):
+        passes: list[NeutronEdgePass] = passes or [
+            MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
+            MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
+        ]
+
+        super().__init__(
+            passes,
+            steps=10,  # Empirical value. At most 10 cycles of passes will be run.
+        )
+
+    def _transform_graph_module(self, module: nn.Module) -> PassResult:
+        """Apply the passes to a single graph module."""
+        pass_result: PassResult = super().__call__(module)
+
+        graph_module = pass_result.graph_module
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        return pass_result
+
+    def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
+        """Apply the passes to all graph modules in the edge program."""
+        new_programs: dict[str, ExportedProgram] = {}
+
+        for name, program in epm._edge_programs.items():
+            pass_result = self._transform_graph_module(program.graph_module)
+
+            if pass_result.modified:
+                # Create a new exported program.
+                new_program = ExportedProgram(
+                    root=pass_result.graph_module,
+                    graph=pass_result.graph_module.graph,
+                    graph_signature=_get_updated_graph_signature(
+                        program.graph_signature, pass_result.graph_module
+                    ),
+                    state_dict=program.state_dict,
+                    range_constraints=_get_updated_range_constraints(
+                        pass_result.graph_module
+                    ),
+                    module_call_graph=copy.deepcopy(program._module_call_graph),
+                    example_inputs=program.example_inputs,
+                    constants=program.constants,
+                    verifiers=[program.verifier],
+                )
+                new_program.graph_module.meta.update(program.graph_module.meta)
+                new_program.graph_module.meta.update(pass_result.graph_module.meta)
+
+            else:
+                # Keep the old exported program.
+                new_program = program
+
+            new_programs[name] = new_program
+
+        if len(new_programs) == 0:
+            # No passes were run, return the old EdgeProgramManager.
+            return epm
+
+        else:
+            # Return a new EdgeProgramManager with the updated programs.
+            return EdgeProgramManager(
+                new_programs, copy.deepcopy(epm._config_methods), epm.compile_config
+            )
diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py
index 5820d3c95d3..a426702cbba 100644
--- a/backends/nxp/tests/executorch_pipeline.py
+++ b/backends/nxp/tests/executorch_pipeline.py
@@ -9,6 +9,9 @@
 from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
     RemoveIOQuantOpsPass,
 )
+from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
+    NeutronEdgePassManager,
+)
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -17,8 +20,8 @@
     EdgeProgramManager,
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
-    to_edge_transform_and_lower,
 )
+from executorch.extension.export_util.utils import export_to_edge
 from torch import nn
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -71,19 +74,22 @@ def to_quantized_edge_program(
         exir_program_aten.module(), calibration_inputs
     )
 
+    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
+    edge_program_manager = export_to_edge(
+        exir_program_aten__module_quant,
+        example_input,
+        edge_compile_config=edge_compile_config,
+    )
+
+    edge_program_manager = NeutronEdgePassManager()(edge_program_manager)
+
     compile_spec = generate_neutron_compile_spec(
         target,
         operators_not_to_delegate=operators_not_to_delegate,
         neutron_converter_flavor=neutron_converter_flavor,
     )
     partitioner = NeutronPartitioner(compile_spec)
-    edge_program_manager = to_edge_transform_and_lower(
-        torch.export.export(
-            exir_program_aten__module_quant, example_input, strict=True
-        ),
-        partitioner=[partitioner],
-        compile_config=EdgeCompileConfig(_check_ir_validity=False),
-    )
+    edge_program_manager = edge_program_manager.to_backend(partitioner)
 
     if remove_quant_io_ops:
         edge_program_manager = edge_program_manager.transform(
diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py
index 3aafab36a95..19a253dccc8 100644
--- a/backends/nxp/tests/models.py
+++ b/backends/nxp/tests/models.py
@@ -142,6 +142,24 @@ def forward(self, x):
         return x
 
 
+class ConvFCFCSoftmaxModuleWithoutReshape(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv = torch.nn.Conv2d(4, 5, 2, bias=False)
+        self.fc1 = torch.nn.Linear(32, 16)
+        self.fc2 = torch.nn.Linear(16, 8)
+        self.softmax = torch.nn.Softmax(1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = self.softmax(x)
+
+        return x
+
+
 class ConstantPadNDModule(torch.nn.Module):
     def __init__(self, paddings: Collection[int], constant: float | int | None = None):
         super().__init__()
diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py
index c058543be2d..d932bbef6b0 100644
--- a/backends/nxp/tests/test_batch_norm_fusion.py
+++ b/backends/nxp/tests/test_batch_norm_fusion.py
@@ -15,6 +15,9 @@
     AddMMConverter,
     MMConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import (
+    ViewCopyConverter,
+)
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
 from executorch.backends.nxp.tests.executors import OverrideSupportedTargets
 from torch import nn
@@ -203,12 +206,13 @@ def test_batch_norm_linear_fusing__full_pipeline(bias: bool):
     # But that doesn't affect the validity of this test.
     with OverrideSupportedTargets(AddMMConverter, new_targets=[]):
         with OverrideSupportedTargets(MMConverter, new_targets=[]):
-            edge_program = to_quantized_edge_program(
-                module, tuple(input_shape)
-            ).exported_program()
-            nodes = list(edge_program.graph.nodes)
+            with OverrideSupportedTargets(ViewCopyConverter, new_targets=[]):
+                edge_program = to_quantized_edge_program(
+                    module, tuple(input_shape)
+                ).exported_program()
+                nodes = list(edge_program.graph.nodes)
 
-    assert len(nodes) == 14
+    assert len(nodes) == 18
     assert not any(
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in nodes
diff --git a/backends/nxp/tests/test_edge_passes.py b/backends/nxp/tests/test_edge_passes.py
new file mode 100644
index 00000000000..23515038671
--- /dev/null
+++ b/backends/nxp/tests/test_edge_passes.py
@@ -0,0 +1,83 @@
+import numpy as np
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
+    ViewCopyConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    EdgeProgramExecutor,
+    OverrideSupportedTargets,
+)
+from executorch.backends.nxp.tests.models import ConvFCFCSoftmaxModuleWithoutReshape
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx import Graph, Node
+
+
+def _is_view_copy(node_: Node) -> bool:
+    return (
+        node_.op == "call_function"
+        and node_.target == exir_ops.edge.aten.view_copy.default
+    )
+
+
+def _is_dequantize(node_: Node) -> bool:
+    return (
+        node_.op == "call_function"
+        and node_.target.__name__
+        == "quantized_decomposed.dequantize_per_tensor.default"
+    )
+
+
+def _is_quantize(node_: Node) -> bool:
+    return (
+        node_.op == "call_function"
+        and node_.target.__name__ == "quantized_decomposed.quantize_per_tensor.default"
+    )
+
+
+def _find_view_copy_node_indices(graph_nodes: list[Node]) -> list[int]:
+    view_copy_nodes_indices = []
+
+    for idx, node in enumerate(graph_nodes):
+        if _is_view_copy(node):
+            view_copy_nodes_indices.append(idx)
+
+    return view_copy_nodes_indices
+
+
+def _assert_nodes_form_a_view_copy_qdq_cluster(graph: Graph, node_indices: list[int]):
+    assert len(node_indices) == 3
+
+    nodes = list(graph.nodes)
+    assert _is_dequantize(dequantize := nodes[node_indices[0]])
+    assert _is_view_copy(view_copy := nodes[node_indices[1]])
+    assert _is_quantize(quantize := nodes[node_indices[2]])
+
+    # Make sure the nodes are properly connected.
+    assert view_copy.args[0] == dequantize
+    assert quantize.args[0] == view_copy
+
+
+def test_moving_view_copy_into_separate_qdq_clusters():
+    model = ConvFCFCSoftmaxModuleWithoutReshape()
+    input_shape = (1, 4, 3, 33)
+
+    # Prohibit `view_copy` conversion for testing purposes.
+    with OverrideSupportedTargets(ViewCopyConverter, new_targets=[]):
+        epm = to_quantized_edge_program(model, input_shape, target="imxrt700")
+    exported_program = epm.exported_program()
+
+    nodes = list(exported_program.graph_module.graph.nodes)
+    assert len(nodes) == 28
+
+    view_copy_indices = _find_view_copy_node_indices(nodes)
+
+    assert len(view_copy_indices) == 4
+    for idx in view_copy_indices:
+        _assert_nodes_form_a_view_copy_qdq_cluster(
+            exported_program.graph, node_indices=[idx - 1, idx, idx + 1]
+        )
+
+    # Make sure the program is runnable.
+    input_data = np.random.random(input_shape).astype("float32")
+    program_executor = EdgeProgramExecutor(exported_program)
+    program_executor.inference(input_data)
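Below is a minimal usage sketch, not part of the diff itself, showing how the new `NeutronEdgePassManager` slots into the lowering flow; it mirrors the change to `backends/nxp/tests/executorch_pipeline.py` above. The function name and its parameters (`quantized_module`, `example_input`, `target`, `neutron_converter_flavor`) are illustrative placeholders for values produced by the surrounding quantization pipeline.

from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
    NeutronEdgePassManager,
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.extension.export_util.utils import export_to_edge


def lower_with_neutron_edge_passes(
    quantized_module, example_input, target: str, neutron_converter_flavor: str
) -> EdgeProgramManager:
    # Convert the quantized ATen-level module to the edge dialect.
    edge_program_manager = export_to_edge(
        quantized_module,
        example_input,
        edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    # Run the NXP edge passes. Auxiliary operators (currently `view_copy` next to
    # addmm/mm) end up in their own dequantize -> op -> quantize clusters.
    edge_program_manager = NeutronEdgePassManager()(edge_program_manager)

    # Partition and delegate to the Neutron backend, as in the existing pipeline.
    compile_spec = generate_neutron_compile_spec(
        target, neutron_converter_flavor=neutron_converter_flavor
    )
    return edge_program_manager.to_backend(NeutronPartitioner(compile_spec))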