diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py index 6af372c65bd..61429e78840 100644 --- a/.jenkins/validate_tutorials_built.py +++ b/.jenkins/validate_tutorials_built.py @@ -23,7 +23,7 @@ "beginner_source/examples_autograd/polynomial_autograd", "beginner_source/examples_autograd/polynomial_custom_function", "intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py - "intermediate_source/fx_conv_bn_fuser", + "intermediate_source/torch_compile_conv_bn_fuser", "intermediate_source/_torch_export_nightly_tutorial", # does not work on release "advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker "prototype_source/fx_graph_mode_ptq_dynamic", diff --git a/index.rst b/index.rst index 82f435d7db7..99ed99b2ea2 100644 --- a/index.rst +++ b/index.rst @@ -348,13 +348,6 @@ Welcome to PyTorch Tutorials .. Code Transformations with FX -.. customcarditem:: - :header: Building a Convolution/Batch Norm fuser in FX - :card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference. - :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png - :link: intermediate/fx_conv_bn_fuser.html - :tags: FX - .. customcarditem:: :header: Building a Simple Performance Profiler with FX :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics @@ -583,6 +576,13 @@ Welcome to PyTorch Tutorials :link: intermediate/torch_compile_tutorial.html :tags: Model-Optimization +.. customcarditem:: + :header: Building a Convolution/Batch Norm fuser in torch.compile + :card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference. 
+ :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: intermediate/torch_compile_conv_bn_fuser.html + :tags: Model-Optimization + .. customcarditem:: :header: Inductor CPU Backend Debugging and Profiling :card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend. @@ -950,7 +950,6 @@ Additional Resources :hidden: :caption: Code Transforms with FX - intermediate/fx_conv_bn_fuser intermediate/fx_profiling_tutorial .. toctree:: @@ -1001,6 +1000,7 @@ Additional Resources intermediate/nvfuser_intro_tutorial intermediate/ax_multiobjective_nas_tutorial intermediate/torch_compile_tutorial + intermediate/torch_compile_conv_bn_fuser intermediate/compiled_autograd_tutorial intermediate/inductor_debug_cpu intermediate/scaled_dot_product_attention_tutorial diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py deleted file mode 100644 index 547f93fb7f1..00000000000 --- a/intermediate_source/fx_conv_bn_fuser.py +++ /dev/null @@ -1,262 +0,0 @@ -# -*- coding: utf-8 -*- -""" -(beta) Building a Convolution/Batch Norm fuser in FX -******************************************************* -**Author**: `Horace He `_ - -In this tutorial, we are going to use FX, a toolkit for composable function -transformations of PyTorch, to do the following: - -1) Find patterns of conv/batch norm in the data dependencies. -2) For the patterns found in 1), fold the batch norm statistics into the convolution weights. - -Note that this optimization only works for models in inference mode (i.e. `mode.eval()`) - -We will be building the fuser that exists here: -https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py - -""" - - -###################################################################### -# First, let's get some imports out of the way (we will be using all -# of these later in the code). 
- -from typing import Type, Dict, Any, Tuple, Iterable -import copy -import torch.fx as fx -import torch -import torch.nn as nn - -###################################################################### -# For this tutorial, we are going to create a model consisting of convolutions -# and batch norms. Note that this model has some tricky components - some of -# the conv/batch norm patterns are hidden within Sequentials and one of the -# ``BatchNorms`` is wrapped in another Module. - -class WrappedBatchNorm(nn.Module): - def __init__(self): - super().__init__() - self.mod = nn.BatchNorm2d(1) - def forward(self, x): - return self.mod(x) - -class M(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 1, 1) - self.bn1 = nn.BatchNorm2d(1) - self.conv2 = nn.Conv2d(1, 1, 1) - self.nested = nn.Sequential( - nn.BatchNorm2d(1), - nn.Conv2d(1, 1, 1), - ) - self.wrapped = WrappedBatchNorm() - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.nested(x) - x = self.wrapped(x) - return x - -model = M() - -model.eval() - -###################################################################### -# Fusing Convolution with Batch Norm -# ----------------------------------------- -# One of the primary challenges with trying to automatically fuse convolution -# and batch norm in PyTorch is that PyTorch does not provide an easy way of -# accessing the computational graph. FX resolves this problem by symbolically -# tracing the actual operations called, so that we can track the computations -# through the `forward` call, nested within Sequential modules, or wrapped in -# an user-defined module. - -traced_model = torch.fx.symbolic_trace(model) -print(traced_model.graph) - -###################################################################### -# This gives us a graph representation of our model. Note that both the modules -# hidden within the sequential as well as the wrapped Module have been inlined -# into the graph. 
This is the default level of abstraction, but it can be -# configured by the pass writer. More information can be found at the FX -# overview https://pytorch.org/docs/master/fx.html#module-torch.fx - - -#################################### -# Fusing Convolution with Batch Norm -# ---------------------------------- -# Unlike some other fusions, fusion of convolution with batch norm does not -# require any new operators. Instead, as batch norm during inference -# consists of a pointwise add and multiply, these operations can be "baked" -# into the preceding convolution's weights. This allows us to remove the batch -# norm entirely from our model! Read -# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The -# code here is copied from -# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py -# clarity purposes. -def fuse_conv_bn_eval(conv, bn): - """ - Given a conv Module `A` and an batch_norm module `B`, returns a conv - module `C` such that C(x) == B(A(x)) in inference mode. - """ - assert(not (conv.training or bn.training)), "Fusion only for eval!" 
- fused_conv = copy.deepcopy(conv) - - fused_conv.weight, fused_conv.bias = \ - fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, - bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) - - return fused_conv - -def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): - if conv_b is None: - conv_b = torch.zeros_like(bn_rm) - if bn_w is None: - bn_w = torch.ones_like(bn_rm) - if bn_b is None: - bn_b = torch.zeros_like(bn_rm) - bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) - - conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) - conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b - - return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) - - -#################################### -# FX Fusion Pass -# ---------------------------------- -# Now that we have our computational graph as well as a method for fusing -# convolution and batch norm, all that remains is to iterate over the FX graph -# and apply the desired fusions. - - -def _parent_name(target : str) -> Tuple[str, str]: - """ - Splits a ``qualname`` into parent path and last atom. - For example, `foo.bar.baz` -> (`foo.bar`, `baz`) - """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name - -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): - assert(isinstance(node.target, str)) - parent_name, name = _parent_name(node.target) - setattr(modules[parent_name], name, new_module) - - -def fuse(model: torch.nn.Module) -> torch.nn.Module: - model = copy.deepcopy(model) - # The first step of most FX passes is to symbolically trace our model to - # obtain a `GraphModule`. This is a representation of our original model - # that is functionally identical to our original model, except that we now - # also have a graph representation of our forward pass. 
- fx_model: fx.GraphModule = fx.symbolic_trace(model) - modules = dict(fx_model.named_modules()) - - # The primary representation for working with FX are the `Graph` and the - # `Node`. Each `GraphModule` has a `Graph` associated with it - this - # `Graph` is also what generates `GraphModule.code`. - # The `Graph` itself is represented as a list of `Node` objects. Thus, to - # iterate through all of the operations in our graph, we iterate over each - # `Node` in our `Graph`. - for node in fx_model.graph.nodes: - # The FX IR contains several types of nodes, which generally represent - # call sites to modules, functions, or methods. The type of node is - # determined by `Node.op`. - if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it. - continue - # For call sites, `Node.target` represents the module/function/method - # that's being called. Here, we check `Node.target` to see if it's a - # batch norm module, and then check `Node.args[0].target` to see if the - # input `Node` is a convolution. - if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d: - if len(node.args[0].users) > 1: # Output of conv is used by other nodes - continue - conv = modules[node.args[0].target] - bn = modules[node.target] - fused_conv = fuse_conv_bn_eval(conv, bn) - replace_node_module(node.args[0], modules, fused_conv) - # As we've folded the batch nor into the conv, we need to replace all uses - # of the batch norm with the conv. - node.replace_all_uses_with(node.args[0]) - # Now that all uses of the batch norm have been replaced, we can - # safely remove the batch norm. - fx_model.graph.erase_node(node) - fx_model.graph.lint() - # After we've modified our graph, we need to recompile our graph in order - # to keep the generated code in sync. - fx_model.recompile() - return fx_model - - -###################################################################### -# .. 
note:: -# We make some simplifications here for demonstration purposes, such as only -# matching 2D convolutions. View -# https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py -# for a more usable pass. - -###################################################################### -# Testing out our Fusion Pass -# ----------------------------------------- -# We can now run this fusion pass on our initial toy model and verify that our -# results are identical. In addition, we can print out the code for our fused -# model and verify that there are no more batch norms. - - -fused_model = fuse(model) -print(fused_model.code) -inp = torch.randn(5, 1, 1, 1) -torch.testing.assert_allclose(fused_model(inp), model(inp)) - - -###################################################################### -# Benchmarking our Fusion on ResNet18 -# ----------------------------------- -# We can test our fusion pass on a larger model like ResNet18 and see how much -# this pass improves inference performance. -import torchvision.models as models -import time - -rn18 = models.resnet18() -rn18.eval() - -inp = torch.randn(10, 3, 224, 224) -output = rn18(inp) - -def benchmark(model, iters=20): - for _ in range(10): - model(inp) - begin = time.time() - for _ in range(iters): - model(inp) - return str(time.time()-begin) - -fused_rn18 = fuse(rn18) -print("Unfused time: ", benchmark(rn18)) -print("Fused time: ", benchmark(fused_rn18)) -###################################################################### -# As we previously saw, the output of our FX transformation is -# ("torchscriptable") PyTorch code, we can easily ``jit.script`` the output to try -# and increase our performance even more. In this way, our FX model -# transformation composes with TorchScript with no issues. 
-jit_rn18 = torch.jit.script(fused_rn18) -print("jit time: ", benchmark(jit_rn18)) - - -############ -# Conclusion -# ---------- -# As we can see, using FX we can easily write static graph transformations on -# PyTorch code. -# -# Since FX is still in beta, we would be happy to hear any -# feedback you have about using it. Please feel free to use the -# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker -# (https://github.com/pytorch/pytorch/issues) to provide any feedback -# you might have. diff --git a/intermediate_source/fx_profiling_tutorial.py b/intermediate_source/fx_profiling_tutorial.py index 8caaf7be39b..7f31338d002 100644 --- a/intermediate_source/fx_profiling_tutorial.py +++ b/intermediate_source/fx_profiling_tutorial.py @@ -216,9 +216,6 @@ def summary(self, should_sort : bool = False) -> str: # # * ``MaxPool2d`` takes up the most time. This is a known issue: # https://github.com/pytorch/pytorch/issues/51393 -# * BatchNorm2d also takes up significant time. We can continue this -# line of thinking and optimize this in the Conv-BN Fusion with FX -# `tutorial `_. # # # Conclusion diff --git a/intermediate_source/torch_compile_conv_bn_fuser.py b/intermediate_source/torch_compile_conv_bn_fuser.py new file mode 100644 index 00000000000..e057d145499 --- /dev/null +++ b/intermediate_source/torch_compile_conv_bn_fuser.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +""" +Building a Convolution/Batch Norm fuser with torch.compile +=========================================================== + +**Author:** `Horace He `_, `Will Feng `_ + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites + + * How to register custom fusion patterns with torch.compile's pattern matcher + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + :class-card: card-prerequisites + + * PyTorch v2.7.0 + +.. note:: + This optimization only works for models in inference mode (i.e. 
``model.eval()``). + However, torch.compile's pattern matching system works for both training and inference. + +""" + + +###################################################################### +# First, let's get some imports out of the way (we will be using all +# of these later in the code). + +from typing import Type, Dict, Any, Tuple, Iterable +import copy +import torch +import torch.nn as nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### +# For this tutorial, we are going to create a model consisting of convolutions +# and batch norms. Note that this model has some tricky components - some of +# the conv/batch norm patterns are hidden within Sequentials and one of the +# ``BatchNorms`` is wrapped in another Module. + +class WrappedBatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.mod = nn.BatchNorm2d(1) + def forward(self, x): + return self.mod(x) + +class M(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 1, 1) + self.bn1 = nn.BatchNorm2d(1) + self.conv2 = nn.Conv2d(1, 1, 1) + self.nested = nn.Sequential( + nn.BatchNorm2d(1), + nn.Conv2d(1, 1, 1), + ) + self.wrapped = WrappedBatchNorm() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.nested(x) + x = self.wrapped(x) + return x + +model = M().to(device) +model.eval() + +###################################################################### +# Fusing Convolution with Batch Norm +# ----------------------------------------- +# One of the primary challenges with trying to automatically fuse convolution +# and batch norm in PyTorch is that PyTorch does not provide an easy way of +# accessing the computational graph. 
torch.compile resolves this problem by +# capturing the computational graph during compilation, allowing us to apply +# pattern-based optimizations across the entire model, including operations +# nested within Sequential modules or wrapped in custom modules. +import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import register_replacement + +###################################################################### +# torch.compile will capture a graph representation of our model. During +# compilation, modules hidden within Sequential containers and wrapped +# modules are all inlined into the graph, making them available for +# pattern matching and optimization. + + +#################################### +# Fusing Convolution with Batch Norm +# ---------------------------------- +# Unlike some other fusions, fusion of convolution with batch norm does not +# require any new operators. Instead, as batch norm during inference +# consists of a pointwise add and multiply, these operations can be "baked" +# into the preceding convolution's weights. This allows us to remove the batch +# norm entirely from our model! Read +# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The +# code here is copied from +# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py +# for clarity purposes. +def fuse_conv_bn_eval(conv, bn): + """ + Given a conv Module `A` and a batch_norm module `B`, returns a conv + module `C` such that C(x) == B(A(x)) in inference mode. + """ + assert(not (conv.training or bn.training)), "Fusion only for eval!" 
+ fused_conv = copy.deepcopy(conv) + + fused_conv.weight, fused_conv.bias = \ + fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + + return fused_conv + +def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) + conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + + return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) + + +#################################### +# Pattern Matching with torch.compile +# ------------------------------------ +# Now that we have our fusion logic, we need to register a pattern that +# torch.compile's pattern matcher will recognize and replace during +# compilation. + +# Define the pattern we want to match: conv2d followed by batch_norm +def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias): + conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias) + bn_out = torch.nn.functional.batch_norm( + conv_out, bn_mean, bn_var, bn_weight, bn_bias, + training=False, eps=1e-5 + ) + return bn_out + +def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias): + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias + ) + return torch.nn.functional.conv2d(x, fused_weight, fused_bias) + +# Example inputs are needed to trace the pattern functions. +# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement. +# These are used to trace the pattern functions to create the match template. +# IMPORTANT: The pattern matcher is shape-agnostic! 
The specific shapes you use here +# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence +# will be matched regardless of channels, kernel size, or spatial dimensions. +# - x: input tensor (batch_size, channels, height, width) +# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w) +# - conv_bias: (out_channels,) +# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels +example_inputs = [ + torch.randn(1, 1, 4, 4).to(device), # x: input tensor + torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel + torch.randn(1).to(device), # conv_bias: 1 output channel + torch.randn(1).to(device), # bn_mean: batch norm running mean + torch.randn(1).to(device), # bn_var: batch norm running variance + torch.randn(1).to(device), # bn_weight: batch norm weight (gamma) + torch.randn(1).to(device), # bn_bias: batch norm bias (beta) +] + +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._inductor import config + +# Create a pattern matcher pass and register our pattern +patterns = PatternMatcherPass() + +register_replacement( + conv_bn_pattern, + conv_bn_replacement, + example_inputs, + pm.fwd_only, + patterns, +) + +# Create a custom pass function that applies our patterns +def conv_bn_fusion_pass(graph): + return patterns.apply(graph) + +# Set our custom pass in the config +config.post_grad_custom_post_pass = conv_bn_fusion_pass + + +###################################################################### +# .. note:: +# We make some simplifications here for demonstration purposes, such as only +# matching 2D convolutions. The pattern matcher in torch.compile +# can handle more complex patterns. 
+ +###################################################################### +# Testing out our Fusion Pass +# ----------------------------------------- +# We can now run this fusion pass on our initial toy model and verify that our +# results are identical. In addition, we can check the pattern matcher's +# counters to verify that every conv/batch norm pair was matched and fused. + +from torch._dynamo.utils import counters + +# Clear the counters before compilation +counters.clear() + +# Ensure pattern matcher is enabled +config.pattern_matcher = True + +fused_model = torch.compile(model, backend="inductor") +inp = torch.randn(5, 1, 1, 1).to(device) + +# Run the model to trigger compilation and pattern matching +with torch.no_grad(): + output = fused_model(inp) + expected = model(inp) + torch.testing.assert_close(output, expected) + +# Check how many patterns were matched +assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched" + +# Create a model with different shapes than our example_inputs +test_model_diff_shape = nn.Sequential( + nn.Conv2d(3, 16, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, 7), + nn.BatchNorm2d(32), +).to(device).eval() + +counters.clear() +compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor") +test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device) +with torch.no_grad(): + compiled_diff_shape(test_input_diff_shape) + +# Check how many patterns were matched +assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched" + + +###################################################################### +# Benchmarking our Fusion on ResNet18 +# ----------------------------------- +# We can test our fusion pass on a larger model like ResNet18 and see how much +# this pass improves inference performance. 
+import torchvision.models as models +import time + +rn18 = models.resnet18().to(device) +rn18.eval() + +inp = torch.randn(10, 3, 224, 224).to(device) +output = rn18(inp) + +def benchmark(model, iters=20): + with torch.no_grad(): + for _ in range(10): + model(inp) + begin = time.time() + for _ in range(iters): + model(inp) + return str(time.time()-begin) + +# Benchmark original model +print("Original model time: ", benchmark(rn18)) + +# Compile with our custom pattern +compiled_with_pattern_matching = torch.compile(rn18, backend="inductor") + +# Benchmark compiled model +print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching)) + + +############ +# Conclusion +# ---------- +# As we can see, torch.compile provides a powerful way to implement +# graph transformations and optimizations through pattern matching. +# By registering custom patterns, we can extend torch.compile's +# optimization capabilities to handle domain-specific transformations. +# +# The conv-bn fusion demonstrated here is just one example of what's +# possible with torch.compile's pattern matching system. \ No newline at end of file