Task T228334710 update tuning guide #3433

Merged
merged 4 commits on Jul 9, 2025
122 changes: 59 additions & 63 deletions recipes_source/recipes/tuning_guide.py
@@ -8,10 +8,38 @@
techniques often can be implemented by changing only a few lines of code and can
be applied to a wide range of deep learning models across all domains.

.. grid:: 2

.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites

* General optimization techniques for PyTorch models
* CPU-specific performance optimizations
* GPU acceleration strategies
* Distributed training optimizations

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

* PyTorch 2.0 or later
* Python 3.8 or later
* CUDA-capable GPU (recommended for GPU optimizations)
* Linux, macOS, or Windows operating system

Overview
--------

Performance optimization is crucial for efficient deep learning model training and inference.
This tutorial covers a comprehensive set of techniques to accelerate PyTorch workloads across
different hardware configurations and use cases.

General optimizations
---------------------
"""

import torch
import torchvision

###############################################################################
# Enable asynchronous data loading and augmentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
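#
# As a minimal sketch of the idea (assuming a hypothetical ``train_dataset`` is
# already defined), asynchronous data loading and augmentation can be enabled by
# passing ``num_workers`` (and ``pin_memory`` when training on GPU) to
# ``torch.utils.data.DataLoader``:
#
# .. code-block:: python
#
#    train_loader = torch.utils.data.DataLoader(
#        train_dataset,
#        batch_size=64,
#        shuffle=True,
#        num_workers=4,    # load and augment batches in background worker processes
#        pin_memory=True,  # enables faster host-to-GPU copies
#    )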
@@ -90,7 +118,7 @@
# setting it to zero, for more details refer to the
# `documentation <https://pytorch.org/docs/master/optim.html#torch.optim.Optimizer.zero_grad>`_.
#
-# Alternatively, starting from PyTorch 1.7, call ``model`` or
+# Alternatively, call ``model`` or
# ``optimizer.zero_grad(set_to_none=True)``.
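#
# A minimal sketch of a training step using this call, assuming hypothetical
# ``model``, ``optimizer``, ``loss_fn``, and ``train_loader`` objects:
#
# .. code-block:: python
#
#    for inputs, targets in train_loader:
#        optimizer.zero_grad(set_to_none=True)  # sets ``.grad`` to ``None`` instead of zero-filling
#        loss = loss_fn(model(inputs), targets)
#        loss.backward()
#        optimizer.step()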

###############################################################################
@@ -129,7 +157,7 @@ def gelu(x):
###############################################################################
# Enable channels_last memory format for computer vision models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# PyTorch 1.5 introduced support for ``channels_last`` memory format for
+# PyTorch supports ``channels_last`` memory format for
# convolutional networks. This format is meant to be used in conjunction with
# `AMP <https://pytorch.org/docs/stable/amp.html>`_ to further accelerate
# convolutional neural networks with
@@ -250,65 +278,6 @@ def gelu(x):
#
# export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD

###############################################################################
# Use oneDNN Graph with TorchScript for inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# oneDNN Graph can significantly boost inference performance. It fuses compute-intensive operations such as convolution and matmul with their neighboring operations.
# In PyTorch 2.0, it is supported as a beta feature for ``Float32`` & ``BFloat16`` data-types.
# oneDNN Graph receives the model’s graph and identifies candidates for operator-fusion with respect to the shape of the example input.
# A model should be JIT-traced using an example input.
# Speed-up would then be observed after a couple of warm-up iterations for inputs with the same shape as the example input.
# The example code snippets below use resnet50, but they can be extended to use oneDNN Graph with custom models as well.

# Only this extra line of code is required to use oneDNN Graph
torch.jit.enable_onednn_fusion(True)

###############################################################################
# Using the oneDNN Graph API requires just one extra line of code for inference with Float32.
# If you are using oneDNN Graph, please avoid calling ``torch.jit.optimize_for_inference``.

# sample input should be of the same shape as expected inputs
sample_input = [torch.rand(32, 3, 224, 224)]
# Using resnet50 from torchvision in this example for illustrative purposes,
# but the line below can indeed be modified to use custom models as well.
model = getattr(torchvision.models, "resnet50")().eval()
# Tracing the model with example input
traced_model = torch.jit.trace(model, sample_input)
# Invoking torch.jit.freeze
traced_model = torch.jit.freeze(traced_model)

###############################################################################
# Once a model is JIT-traced with a sample input, it can then be used for inference after a couple of warm-up runs.

with torch.no_grad():
    # a couple of warm-up runs
    traced_model(*sample_input)
    traced_model(*sample_input)
    # speedup would be observed after warm-up runs
    traced_model(*sample_input)

###############################################################################
# While the JIT fuser for oneDNN Graph also supports inference with the ``BFloat16`` datatype,
# the performance benefit of oneDNN Graph is only seen on machines with the AVX512_BF16
# instruction set architecture (ISA).
# The following code snippet serves as an example of using the ``BFloat16`` datatype for inference with oneDNN Graph:

# AMP for JIT mode is enabled by default and diverges from its eager mode counterpart,
# so disable it here
torch._C._jit_set_autocast_mode(False)

with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
    # Conv-BatchNorm folding for CNN-based vision models should be done with
    # ``torch.fx.experimental.optimization.fuse`` when AMP is used
    import torch.fx.experimental.optimization as optimization
    # Please note that optimization.fuse need not be called when AMP is not used
    model = optimization.fuse(model)
    # re-trace and freeze the eager model with the same sample input as above
    model = torch.jit.trace(model, sample_input)
    model = torch.jit.freeze(model)
    # a couple of warm-up runs
    model(*sample_input)
    model(*sample_input)
    # speedup would be observed in subsequent runs
    model(*sample_input)


###############################################################################
# Train a model on CPU with PyTorch ``DistributedDataParallel`` (DDP) functionality
@@ -426,9 +395,8 @@ def gelu(x):
# * enable AMP
#
# * Introduction to Mixed Precision Training and AMP:
# `video <https://www.youtube.com/watch?v=jF4-_ZK_tyc&feature=youtu.be>`_,
# `slides <https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf>`_
-# * native PyTorch AMP is available starting from PyTorch 1.6:
+# * native PyTorch AMP is available:
# `documentation <https://pytorch.org/docs/stable/amp.html>`_,
# `examples <https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples>`_,
# `tutorial <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
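#
# A minimal sketch of an AMP training step (assuming hypothetical ``model``,
# ``optimizer``, ``loss_fn``, and ``train_loader`` objects on a CUDA device):
#
# .. code-block:: python
#
#    scaler = torch.cuda.amp.GradScaler()
#    for inputs, targets in train_loader:
#        optimizer.zero_grad(set_to_none=True)
#        with torch.cuda.amp.autocast():
#            loss = loss_fn(model(inputs), targets)
#        scaler.scale(loss).backward()  # scale the loss to avoid gradient underflow
#        scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
#        scaler.update()                # adjusts the scale factor for the next iteration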
@@ -536,3 +504,31 @@ def gelu(x):
# approximately constant number of tokens (and variable number of sequences in a
# batch), other models solve imbalance by bucketing samples with similar
# sequence length or even by sorting the dataset by sequence length.
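#
# A minimal sketch of length-based bucketing, assuming hypothetical ``lengths``
# (one sequence length per sample), ``batch_size``, and ``train_dataset`` objects:
#
# .. code-block:: python
#
#    import random
#
#    # sort sample indices by sequence length, then cut them into fixed-size batches
#    sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])
#    batches = [sorted_indices[i:i + batch_size]
#               for i in range(0, len(sorted_indices), batch_size)]
#    # shuffle whole batches (not individual samples) so each batch stays length-homogeneous
#    random.shuffle(batches)
#    train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batches)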

###############################################################################
# Conclusion
# ----------
#
# This tutorial covered a comprehensive set of performance optimization techniques
# for PyTorch models. The key takeaways include:
#
# * **General optimizations**: Enable async data loading, disable gradients for
# inference, fuse operations with ``torch.compile``, and use efficient memory formats
# * **CPU optimizations**: Leverage NUMA controls, optimize OpenMP settings, and
# use efficient memory allocators
# * **GPU optimizations**: Enable Tensor cores, use CUDA graphs, enable cuDNN
# autotuner, and implement mixed precision training
# * **Distributed optimizations**: Use DistributedDataParallel, optimize gradient
# synchronization, and balance workloads across devices
#
# Many of these optimizations can be applied with minimal code changes and provide
# significant performance improvements across a wide range of deep learning models.
#
# Further Reading
# ---------------
#
# * `PyTorch Performance Tuning Documentation <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_
# * `CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html>`_
# * `Distributed Training Documentation <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
# * `Mixed Precision Training <https://pytorch.org/docs/stable/amp.html>`_
# * `torch.compile Tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_