From 4b0fedec85c34df1a77129db7934669b21131fa5 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 01:37:31 -0700 Subject: [PATCH 1/5] Create torch_compile_conv_bn_fuser tutorial adapted from fx_conv_bn_fuser --- .jenkins/validate_tutorials_built.py | 2 +- index.rst | 16 +- intermediate_source/fx_conv_bn_fuser.py | 265 ++++++++++--------- intermediate_source/fx_profiling_tutorial.py | 4 +- 4 files changed, 156 insertions(+), 131 deletions(-) diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py index 6af372c65bd..61429e78840 100644 --- a/.jenkins/validate_tutorials_built.py +++ b/.jenkins/validate_tutorials_built.py @@ -23,7 +23,7 @@ "beginner_source/examples_autograd/polynomial_autograd", "beginner_source/examples_autograd/polynomial_custom_function", "intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py - "intermediate_source/fx_conv_bn_fuser", + "intermediate_source/torch_compile_conv_bn_fuser", "intermediate_source/_torch_export_nightly_tutorial", # does not work on release "advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker "prototype_source/fx_graph_mode_ptq_dynamic", diff --git a/index.rst b/index.rst index 82f435d7db7..99ed99b2ea2 100644 --- a/index.rst +++ b/index.rst @@ -348,13 +348,6 @@ Welcome to PyTorch Tutorials .. Code Transformations with FX -.. customcarditem:: - :header: Building a Convolution/Batch Norm fuser in FX - :card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference. - :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png - :link: intermediate/fx_conv_bn_fuser.html - :tags: FX - .. 
customcarditem:: :header: Building a Simple Performance Profiler with FX :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics @@ -583,6 +576,13 @@ Welcome to PyTorch Tutorials :link: intermediate/torch_compile_tutorial.html :tags: Model-Optimization +.. customcarditem:: + :header: Building a Convolution/Batch Norm fuser in torch.compile + :card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference. + :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: intermediate/torch_compile_conv_bn_fuser.html + :tags: Model-Optimization + .. customcarditem:: :header: Inductor CPU Backend Debugging and Profiling :card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend. @@ -950,7 +950,6 @@ Additional Resources :hidden: :caption: Code Transforms with FX - intermediate/fx_conv_bn_fuser intermediate/fx_profiling_tutorial .. 
toctree:: @@ -1001,6 +1000,7 @@ Additional Resources intermediate/nvfuser_intro_tutorial intermediate/ax_multiobjective_nas_tutorial intermediate/torch_compile_tutorial + intermediate/torch_compile_conv_bn_fuser intermediate/compiled_autograd_tutorial intermediate/inductor_debug_cpu intermediate/scaled_dot_product_attention_tutorial diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index 547f93fb7f1..d650e372ca5 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -1,19 +1,20 @@ # -*- coding: utf-8 -*- """ -(beta) Building a Convolution/Batch Norm fuser in FX -******************************************************* -**Author**: `Horace He `_ +Building a Convolution/Batch Norm fuser with torch.compile +****************************************************************** +**Author**: `Horace He `__, `Will Feng `__ -In this tutorial, we are going to use FX, a toolkit for composable function -transformations of PyTorch, to do the following: +In this tutorial, we are going to use torch.compile and its pattern matching +capabilities to do the following: 1) Find patterns of conv/batch norm in the data dependencies. 2) For the patterns found in 1), fold the batch norm statistics into the convolution weights. -Note that this optimization only works for models in inference mode (i.e. `mode.eval()`) +Note that this specific optimization only works for models in inference mode (i.e. `mode.eval()`). +But the pattern matching system in torch.compile works for both training and inference. -We will be building the fuser that exists here: -https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py +We will demonstrate how to register custom fusion patterns with torch.compile's +pattern matcher to optimize model performance. 
""" @@ -24,10 +25,11 @@ from typing import Type, Dict, Any, Tuple, Iterable import copy -import torch.fx as fx import torch import torch.nn as nn +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ###################################################################### # For this tutorial, we are going to create a model consisting of convolutions # and batch norms. Note that this model has some tricky components - some of @@ -61,8 +63,7 @@ def forward(self, x): x = self.wrapped(x) return x -model = M() - +model = M().to(device) model.eval() ###################################################################### @@ -70,20 +71,18 @@ def forward(self, x): # ----------------------------------------- # One of the primary challenges with trying to automatically fuse convolution # and batch norm in PyTorch is that PyTorch does not provide an easy way of -# accessing the computational graph. FX resolves this problem by symbolically -# tracing the actual operations called, so that we can track the computations -# through the `forward` call, nested within Sequential modules, or wrapped in -# an user-defined module. - -traced_model = torch.fx.symbolic_trace(model) -print(traced_model.graph) +# accessing the computational graph. torch.compile resolves this problem by +# capturing the computational graph during compilation, allowing us to apply +# pattern-based optimizations across the entire model, including operations +# nested within Sequential modules or wrapped in custom modules. +import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import register_replacement ###################################################################### -# This gives us a graph representation of our model. Note that both the modules -# hidden within the sequential as well as the wrapped Module have been inlined -# into the graph. This is the default level of abstraction, but it can be -# configured by the pass writer. 
More information can be found at the FX -# overview https://pytorch.org/docs/master/fx.html#module-torch.fx +# torch.compile will capture a graph representation of our model. During +# compilation, modules hidden within Sequential containers and wrapped +# modules are all inlined into the graph, making them available for +# pattern matching and optimization. #################################### @@ -128,78 +127,74 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): #################################### -# FX Fusion Pass -# ---------------------------------- -# Now that we have our computational graph as well as a method for fusing -# convolution and batch norm, all that remains is to iterate over the FX graph -# and apply the desired fusions. - - -def _parent_name(target : str) -> Tuple[str, str]: - """ - Splits a ``qualname`` into parent path and last atom. - For example, `foo.bar.baz` -> (`foo.bar`, `baz`) - """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name - -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): - assert(isinstance(node.target, str)) - parent_name, name = _parent_name(node.target) - setattr(modules[parent_name], name, new_module) - - -def fuse(model: torch.nn.Module) -> torch.nn.Module: - model = copy.deepcopy(model) - # The first step of most FX passes is to symbolically trace our model to - # obtain a `GraphModule`. This is a representation of our original model - # that is functionally identical to our original model, except that we now - # also have a graph representation of our forward pass. - fx_model: fx.GraphModule = fx.symbolic_trace(model) - modules = dict(fx_model.named_modules()) - - # The primary representation for working with FX are the `Graph` and the - # `Node`. Each `GraphModule` has a `Graph` associated with it - this - # `Graph` is also what generates `GraphModule.code`. 
- # The `Graph` itself is represented as a list of `Node` objects. Thus, to - # iterate through all of the operations in our graph, we iterate over each - # `Node` in our `Graph`. - for node in fx_model.graph.nodes: - # The FX IR contains several types of nodes, which generally represent - # call sites to modules, functions, or methods. The type of node is - # determined by `Node.op`. - if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it. - continue - # For call sites, `Node.target` represents the module/function/method - # that's being called. Here, we check `Node.target` to see if it's a - # batch norm module, and then check `Node.args[0].target` to see if the - # input `Node` is a convolution. - if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d: - if len(node.args[0].users) > 1: # Output of conv is used by other nodes - continue - conv = modules[node.args[0].target] - bn = modules[node.target] - fused_conv = fuse_conv_bn_eval(conv, bn) - replace_node_module(node.args[0], modules, fused_conv) - # As we've folded the batch nor into the conv, we need to replace all uses - # of the batch norm with the conv. - node.replace_all_uses_with(node.args[0]) - # Now that all uses of the batch norm have been replaced, we can - # safely remove the batch norm. - fx_model.graph.erase_node(node) - fx_model.graph.lint() - # After we've modified our graph, we need to recompile our graph in order - # to keep the generated code in sync. - fx_model.recompile() - return fx_model +# Pattern Matching with torch.compile +# ------------------------------------ +# Now that we have our fusion logic, we need to register a pattern that +# torch.compile's pattern matcher will recognize and replace during +# compilation. 
+ +# Define the pattern we want to match: conv2d followed by batch_norm +def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias): + conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias) + bn_out = torch.nn.functional.batch_norm( + conv_out, bn_mean, bn_var, bn_weight, bn_bias, + training=False, eps=1e-5 + ) + return bn_out + +def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias): + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias + ) + return torch.nn.functional.conv2d(x, fused_weight, fused_bias) + +# Example inputs are needed to trace the pattern functions. +# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement. +# These are used to trace the pattern functions to create the match template. +# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here +# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence +# will be matched regardless of channels, kernel size, or spatial dimensions. 
+# - x: input tensor (batch_size, channels, height, width) +# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w) +# - conv_bias: (out_channels,) +# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels +example_inputs = [ + torch.randn(1, 1, 4, 4).to(device), # x: input tensor + torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel + torch.randn(1).to(device), # conv_bias: 1 output channel + torch.randn(1).to(device), # bn_mean: batch norm running mean + torch.randn(1).to(device), # bn_var: batch norm running variance + torch.randn(1).to(device), # bn_weight: batch norm weight (gamma) + torch.randn(1).to(device), # bn_bias: batch norm bias (beta) +] + +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._inductor import config + +# Create a pattern matcher pass and register our pattern +patterns = PatternMatcherPass() + +register_replacement( + conv_bn_pattern, + conv_bn_replacement, + example_inputs, + pm.fwd_only, + patterns, +) + +# Create a custom pass function that applies our patterns +def conv_bn_fusion_pass(graph): + return patterns.apply(graph) + +# Set our custom pass in the config +config.post_grad_custom_post_pass = conv_bn_fusion_pass ###################################################################### # .. note:: # We make some simplifications here for demonstration purposes, such as only -# matching 2D convolutions. View -# https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py -# for a more usable pass. +# matching 2D convolutions. The pattern matcher in torch.compile +# can handle more complex patterns. ###################################################################### # Testing out our Fusion Pass @@ -208,11 +203,43 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module: # results are identical. In addition, we can print out the code for our fused # model and verify that there are no more batch norms. 
+from torch._dynamo.utils import counters + +# Clear the counters before compilation +counters.clear() + +# Ensure pattern matcher is enabled +config.pattern_matcher = True -fused_model = fuse(model) -print(fused_model.code) -inp = torch.randn(5, 1, 1, 1) -torch.testing.assert_allclose(fused_model(inp), model(inp)) +fused_model = torch.compile(model, backend="inductor") +inp = torch.randn(5, 1, 1, 1).to(device) + +# Run the model to trigger compilation and pattern matching +with torch.no_grad(): + output = fused_model(inp) + expected = model(inp) + torch.testing.assert_close(output, expected) + +# Check how many patterns were matched +assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched" + +# Create a model with different shapes than our example_inputs +test_model_diff_shape = nn.Sequential( + nn.Conv2d(3, 16, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, 7), + nn.BatchNorm2d(32), +).to(device).eval() + +counters.clear() +compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor") +test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device) +with torch.no_grad(): + compiled_diff_shape(test_input_diff_shape) + +# Check how many patterns were matched +assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched" ###################################################################### @@ -223,40 +250,38 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module: import torchvision.models as models import time -rn18 = models.resnet18() +rn18 = models.resnet18().to(device) rn18.eval() -inp = torch.randn(10, 3, 224, 224) +inp = torch.randn(10, 3, 224, 224).to(device) output = rn18(inp) def benchmark(model, iters=20): - for _ in range(10): - model(inp) - begin = time.time() - for _ in range(iters): - model(inp) - return str(time.time()-begin) - -fused_rn18 = fuse(rn18) -print("Unfused time: ", benchmark(rn18)) -print("Fused time: ", benchmark(fused_rn18)) 
-###################################################################### -# As we previously saw, the output of our FX transformation is -# ("torchscriptable") PyTorch code, we can easily ``jit.script`` the output to try -# and increase our performance even more. In this way, our FX model -# transformation composes with TorchScript with no issues. -jit_rn18 = torch.jit.script(fused_rn18) -print("jit time: ", benchmark(jit_rn18)) + with torch.no_grad(): + for _ in range(10): + model(inp) + begin = time.time() + for _ in range(iters): + model(inp) + return str(time.time()-begin) + +# Benchmark original model +print("Original model time: ", benchmark(rn18)) + +# Compile with our custom pattern +compiled_with_pattern_matching = torch.compile(rn18, backend="inductor") + +# Benchmark compiled model +print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching)) ############ # Conclusion # ---------- -# As we can see, using FX we can easily write static graph transformations on -# PyTorch code. +# As we can see, torch.compile provides a powerful way to implement +# graph transformations and optimizations through pattern matching. +# By registering custom patterns, we can extend torch.compile's +# optimization capabilities to handle domain-specific transformations. # -# Since FX is still in beta, we would be happy to hear any -# feedback you have about using it. Please feel free to use the -# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker -# (https://github.com/pytorch/pytorch/issues) to provide any feedback -# you might have. +# The conv-bn fusion demonstrated here is just one example of what's +# possible with torch.compile's pattern matching system. 
\ No newline at end of file diff --git a/intermediate_source/fx_profiling_tutorial.py b/intermediate_source/fx_profiling_tutorial.py index 8caaf7be39b..b4a37b66bc9 100644 --- a/intermediate_source/fx_profiling_tutorial.py +++ b/intermediate_source/fx_profiling_tutorial.py @@ -217,8 +217,8 @@ def summary(self, should_sort : bool = False) -> str: # * ``MaxPool2d`` takes up the most time. This is a known issue: # https://github.com/pytorch/pytorch/issues/51393 # * BatchNorm2d also takes up significant time. We can continue this -# line of thinking and optimize this in the Conv-BN Fusion with FX -# `tutorial `_. +# line of thinking and optimize this in the Conv-BN Fusion with torch.compile +# `tutorial `_. # # # Conclusion From d4dddaa5707a0417f8c907764aa8cf2f45fddf04 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 11:19:00 -0700 Subject: [PATCH 2/5] update --- intermediate_source/fx_conv_bn_fuser.py | 25 ++++++++++++-------- intermediate_source/fx_profiling_tutorial.py | 3 --- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index d650e372ca5..e057d145499 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -1,20 +1,25 @@ # -*- coding: utf-8 -*- """ Building a Convolution/Batch Norm fuser with torch.compile -****************************************************************** -**Author**: `Horace He `__, `Will Feng `__ +=========================================================== -In this tutorial, we are going to use torch.compile and its pattern matching -capabilities to do the following: +**Author:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_ -1) Find patterns of conv/batch norm in the data dependencies. -2) For the patterns found in 1), fold the batch norm statistics into the convolution weights. +.. grid:: 2 -Note that this specific optimization only works for models in inference mode (i.e. `mode.eval()`). 
-But the pattern matching system in torch.compile works for both training and inference. + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites -We will demonstrate how to register custom fusion patterns with torch.compile's -pattern matcher to optimize model performance. + * How to register custom fusion patterns with torch.compile's pattern matcher + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + :class-card: card-prerequisites + + * PyTorch v2.7.0 + +.. note:: + This optimization only works for models in inference mode (i.e. ``model.eval()``). + However, torch.compile's pattern matching system works for both training and inference. """ diff --git a/intermediate_source/fx_profiling_tutorial.py b/intermediate_source/fx_profiling_tutorial.py index b4a37b66bc9..7f31338d002 100644 --- a/intermediate_source/fx_profiling_tutorial.py +++ b/intermediate_source/fx_profiling_tutorial.py @@ -216,9 +216,6 @@ def summary(self, should_sort : bool = False) -> str: # # * ``MaxPool2d`` takes up the most time. This is a known issue: # https://github.com/pytorch/pytorch/issues/51393 -# * BatchNorm2d also takes up significant time. We can continue this -# line of thinking and optimize this in the Conv-BN Fusion with torch.compile -# `tutorial `_. 
# # # Conclusion From 35bcd92691b6ddd4cf65f94f9fef3c7eccdef8ff Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 11:29:55 -0700 Subject: [PATCH 3/5] up --- .../{fx_conv_bn_fuser.py => torch_compile_conv_bn_fuser.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename intermediate_source/{fx_conv_bn_fuser.py => torch_compile_conv_bn_fuser.py} (100%) diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/torch_compile_conv_bn_fuser.py similarity index 100% rename from intermediate_source/fx_conv_bn_fuser.py rename to intermediate_source/torch_compile_conv_bn_fuser.py From d4388d93565a494d2413364fb90723b762c459c0 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 11:32:04 -0700 Subject: [PATCH 4/5] up --- intermediate_source/fx_profiling_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/fx_profiling_tutorial.py b/intermediate_source/fx_profiling_tutorial.py index 7f31338d002..fd7f888eb30 100644 --- a/intermediate_source/fx_profiling_tutorial.py +++ b/intermediate_source/fx_profiling_tutorial.py @@ -230,4 +230,4 @@ def summary(self, should_sort : bool = False) -> str: # feedback you have about using it. Please feel free to use the # PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker # (https://github.com/pytorch/pytorch/issues) to provide any feedback -# you might have. +# you might have. 
\ No newline at end of file From d5aeb88a7f7d97931148cb4172dec0f70831c820 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 11:32:26 -0700 Subject: [PATCH 5/5] up --- intermediate_source/fx_profiling_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/fx_profiling_tutorial.py b/intermediate_source/fx_profiling_tutorial.py index fd7f888eb30..7f31338d002 100644 --- a/intermediate_source/fx_profiling_tutorial.py +++ b/intermediate_source/fx_profiling_tutorial.py @@ -230,4 +230,4 @@ def summary(self, should_sort : bool = False) -> str: # feedback you have about using it. Please feel free to use the # PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker # (https://github.com/pytorch/pytorch/issues) to provide any feedback -# you might have. \ No newline at end of file +# you might have.