Skip to content

NXP Backend: Add infrastructure for pre processing passes in edge dialect #13183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node
from torch.fx.passes.infra.pass_base import PassResult


def insert_qdq_pair_after_node(
    graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
):
    """Insert a `quantize -> dequantize` node pair right after `anchor`.

    All original users of `anchor` are rewired to consume the new `dequantize`
    node instead, so the data flow becomes
    `anchor -> quantize -> dequantize -> <original users of anchor>`.

    :param graph: Graph to modify in place.
    :param anchor: Node after which the QDQ pair will be inserted.
    :param q_params: Quantization parameters (scale, zero point, ...), appended
        to the arguments of both newly created nodes.
    """
    # Insert a Quantize node.
    with graph.inserting_after(anchor):
        quantize_op = graph.create_node(
            op="call_function",
            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # Will be added later.
        )
        # Copy the metadata instead of aliasing the `anchor.meta` dict. With a
        # plain assignment, all three nodes would share one dict, and a later
        # mutation of one node's metadata would silently corrupt the others.
        quantize_op.meta = dict(anchor.meta)

    # Insert a Dequantize node.
    with graph.inserting_after(quantize_op):
        dequantize_op = graph.create_node(
            op="call_function",
            target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(quantize_op,) + q_params,
        )
        dequantize_op.meta = dict(quantize_op.meta)
    anchor.replace_all_uses_with(dequantize_op)

    # Add this at the end, so the `anchor.replace_all_uses_with(dequantize_op)` does not replace the first use of the
    # `quantize_op`.
    quantize_op.args = (anchor,) + q_params


def _is_dequantize(node_: Node) -> bool:
    """Return True iff `node_` is an edge-dialect `dequantize_per_tensor` call."""
    if node_.op != "call_function":
        return False
    dequantize_target = (
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    )
    return node_.target == dequantize_target


def _is_quantize(node_: Node) -> bool:
    """Return True iff `node_` is an edge-dialect `quantize_per_tensor` call."""
    if node_.op != "call_function":
        return False
    quantize_target = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    )
    return node_.target == quantize_target


class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
    """Split a leading auxiliary operator out of a QDQ cluster.

    Detects the pattern (all nodes within a single QDQ cluster)

              │                                              │
        ┌─────▼──────┐                                 ┌─────▼──────┐
        │ dequantize │                                 │ dequantize │
        └─────┬──────┘                                 └─────┬──────┘
        ┌─────▼──────┐                                 ┌─────▼──────┐
        │ <aux_node> │                                 │ <aux_node> │
        └─────┬──────┘                                 └─────┬──────┘
   ┌──────────▼──────────┐      replaced with          ┌─────▼──────┐  ┐
  ⋯┤ <main_cluster_node> ├⋯    ──────────────►         │  quantize  │  │
   └──────────┬──────────┘                             └─────┬──────┘  │ newly added
              ▼                                        ┌─────▼──────┐  │ nodes
              ⋮                                        │ dequantize │  │
        ┌─────▼──────┐                                 └─────┬──────┘  ┘
        │  quantize  │                            ┌──────────▼──────────┐
        └─────┬──────┘                           ⋯┤ <main_cluster_node> ├⋯
              ▼                                   └──────────┬──────────┘
                                                             ▼
                                                             ⋮
                                                       ┌─────▼──────┐
                                                       │  quantize  │
                                                       └─────┬──────┘
                                                             ▼

    i.e. it inserts a new `quantize -> dequantize` pair between `<aux_node>`
    and `<main_cluster_node>`, so the auxiliary operator ends up in its own
    QDQ cluster.
    """

    # Auxiliary operators this pass is allowed to move into a separate QDQ cluster.
    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

    # List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
    allowed_main_cluster_nodes = [
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.mm.default,
    ]

    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Find the first occurrence of the pattern described in the class
        docstring, modify the graph, and return. The parent class keeps calling
        `run()` until no more changes are made.

        :param graph_module: Graph module to transform in place.
        :return: `PassResult` with `modified=True` iff the graph was changed.
        """
        for aux_node in graph_module.graph.nodes:
            if (
                aux_node.op != "call_function"
                or aux_node.target not in self.allowed_auxiliary_nodes
            ):
                continue

            dequantize_node = aux_node.args[0]
            # The first argument could in principle be a non-node (e.g. a
            # constant), so guard before accessing node attributes.
            if not isinstance(dequantize_node, Node) or not _is_dequantize(
                dequantize_node
            ):
                # Not the intended use case.
                continue

            users = list(aux_node.users.keys())
            if len(users) != 1:
                # Not the intended use case.
                continue

            main_cluster_node = users[0]
            if (
                main_cluster_node.op != "call_function"
                or main_cluster_node.target not in self.allowed_main_cluster_nodes
            ):
                # Unsupported `main_cluster_node`.
                continue

            # Make sure the nodes are part of the same QDQ cluster.
            cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
            if any(
                node_ not in cluster
                for node_ in [dequantize_node, aux_node, main_cluster_node]
            ):
                continue

            # ---- The nodes follow the pattern described in the header. ----

            # Reuse the quantization parameters of the leading `dequantize`
            # node for the newly inserted QDQ pair.
            q_params = dequantize_node.args[1:]
            insert_qdq_pair_after_node(graph_module.graph, aux_node, q_params)

            # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the
            # parent class will call this pass again.
            return PassResult(graph_module, True)

        # Nothing was changed.
        return PassResult(graph_module, False)


class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
    """Split a trailing auxiliary operator out of a QDQ cluster.

    Detects the pattern

        dequantize -> ... -> <main_cluster_node> -> <aux_node> -> quantize

    (all nodes within a single QDQ cluster) and inserts a new
    `quantize -> dequantize` pair between `<main_cluster_node>` and
    `<aux_node>`, so the auxiliary operator ends up in its own QDQ cluster:

        dequantize -> ... -> <main_cluster_node> -> quantize ->
            dequantize -> <aux_node> -> quantize
    """

    # Auxiliary operators this pass is allowed to move into a separate QDQ cluster.
    allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

    # List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
    allowed_main_cluster_nodes = [
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.mm.default,
    ]

    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Find the first occurrence of the pattern described in the class
        docstring, modify the graph, and return. The parent class keeps calling
        `run()` until no more changes are made.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            is_aux = (
                node.op == "call_function"
                and node.target in self.allowed_auxiliary_nodes
            )
            if not is_aux:
                continue

            producer = node.args[0]  # Candidate `<main_cluster_node>`.
            if producer.op != "call_function":
                # Unsupported `main_cluster_node`.
                continue
            if producer.target not in self.allowed_main_cluster_nodes:
                # Unsupported `main_cluster_node`.
                continue

            consumers = list(node.users.keys())
            if len(consumers) != 1:
                # Not the intended use case.
                continue
            quantize_node = consumers[0]
            if not _is_quantize(quantize_node):
                # Not the intended use case.
                continue

            # All three nodes must belong to the same QDQ cluster.
            cluster = QDQClusterRecognizer().get_qdq_cluster(producer)
            if not all(n in cluster for n in (quantize_node, node, producer)):
                continue

            # The pattern was matched. Reuse the quantization parameters of the
            # trailing `quantize` node for the newly inserted QDQ pair.
            insert_qdq_pair_after_node(graph, producer, quantize_node.args[1:])

            # The graph changed; stop iterating over it. The parent class will
            # invoke this pass again to process any remaining matches.
            return PassResult(graph_module, True)

        # Nothing was changed.
        return PassResult(graph_module, False)
55 changes: 55 additions & 0 deletions backends/nxp/edge_passes/neutron_edge_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import abstractmethod

import torch

from executorch.exir.pass_base import ExportPass
from torch.fx.passes.infra.pass_base import PassResult


class NeutronEdgePass(ExportPass):
    """Abstract parent class for pre-processing passes on the edge dialect level."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Invoke `self.run()` repeatedly until it stops modifying the graph.

        A child pass must return right after its first modification, because it
        cannot safely mutate the graph while iterating over it. Calling it in a
        loop here lets it process the whole model anyway.
        """
        # Cap the number of invocations so a misbehaving pass cannot loop
        # forever. Every pass will return once it makes a change to the graph,
        # so it must be called multiple times (at most `attempts_left` times).
        attempts_left = len(graph_module.graph.nodes)
        made_changes = False

        while attempts_left > 0:
            result = self.run(graph_module)
            if not result.modified:
                # No more changes have been made.
                return PassResult(self.recompile_module(graph_module), made_changes)
            made_changes = True
            graph_module = result.graph_module
            attempts_left -= 1

        # Iteration limit was reached.
        logging.warning(
            f"The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit."
        )
        return PassResult(self.recompile_module(graph_module), made_changes)

    @abstractmethod
    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Child classes should implement their graph modification here."""
        pass

    def recompile_module(
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct."""
        graph_module.recompile()
        return super().call(graph_module).graph_module
89 changes: 89 additions & 0 deletions backends/nxp/edge_passes/neutron_edge_pass_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Optional

from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
    MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
    MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
)
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
from executorch.exir import EdgeProgramManager
from executorch.exir.program._program import (
    _get_updated_graph_signature,
    _get_updated_range_constraints,
)

from torch import nn
from torch.export import ExportedProgram
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager


class NeutronEdgePassManager(PassManager):
    """Runs the NXP edge-dialect pre-processing passes over all graph modules
    of an `EdgeProgramManager`.
    """

    def __init__(self, passes: Optional[list[NeutronEdgePass]] = None):
        """
        :param passes: Optional list of passes to run. If not provided, the
            default set of NXP edge passes is used.
        """
        passes = passes or [
            MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
            MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
        ]

        super().__init__(
            passes,
            steps=10,  # Empirical value. At most 10 cycles of passes will be run.
        )

    def _transform_graph_module(self, module: nn.Module) -> PassResult:
        """Apply the passes to a single graph module."""
        pass_result: PassResult = super().__call__(module)

        graph_module = pass_result.graph_module
        # Remove nodes orphaned by the passes and make sure the generated
        # Python code matches the modified graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        return pass_result

    def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
        """Apply the passes to all graph modules in the edge program.

        :param epm: Edge program manager whose programs will be transformed.
        :return: A new `EdgeProgramManager` with the transformed programs, or
            `epm` itself if it contains no edge programs.
        """
        new_programs: dict[str, ExportedProgram] = {}

        for name, program in epm._edge_programs.items():
            pass_result = self._transform_graph_module(program.graph_module)

            if pass_result.modified:
                # Create a new exported program, since the graph signature and
                # range constraints may have changed.
                new_program = ExportedProgram(
                    root=pass_result.graph_module,
                    graph=pass_result.graph_module.graph,
                    graph_signature=_get_updated_graph_signature(
                        program.graph_signature, pass_result.graph_module
                    ),
                    state_dict=program.state_dict,
                    range_constraints=_get_updated_range_constraints(
                        pass_result.graph_module
                    ),
                    module_call_graph=copy.deepcopy(program._module_call_graph),
                    example_inputs=program.example_inputs,
                    constants=program.constants,
                    verifiers=[program.verifier],
                )
                new_program.graph_module.meta.update(program.graph_module.meta)
                new_program.graph_module.meta.update(pass_result.graph_module.meta)

            else:
                # Keep the old exported program.
                new_program = program

            new_programs[name] = new_program

        if len(new_programs) == 0:
            # No passes were run, return the old EdgeProgramManager.
            return epm

        else:
            # Return a new EdgeProgramManager with the updated programs.
            return EdgeProgramManager(
                new_programs, copy.deepcopy(epm._config_methods), epm.compile_config
            )
Loading
Loading