Skip to content

Profiler Recipe Update #3435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions recipes_source/recipes/profiler_recipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""
PyTorch Profiler
====================================
**Author:** `Shivam Raikundalia <https://github.com/sraikund16>`_
"""

######################################################################
"""
This recipe explains how to use the PyTorch profiler to measure the time and
memory consumption of the model's operators.

Expand All @@ -12,6 +17,10 @@
In this recipe, we will use a simple ResNet model to demonstrate how to
use the profiler to analyze model performance.

Prerequisites
---------------
- ``torch >= 1.9``

Setup
-----
To install ``torch`` and ``torchvision`` use the following command:
Expand All @@ -20,10 +29,8 @@

pip install torch torchvision


"""


######################################################################
# Steps
# -----
Expand All @@ -45,7 +52,7 @@

import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity
from torch.profiler import profile, ProfilerActivity, record_function


######################################################################
Expand Down Expand Up @@ -135,7 +142,11 @@
# To get a finer granularity of results and include operator input shapes, pass ``group_by_input_shape=True``
# (note: this requires running the profiler with ``record_shapes=True``):

print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
print(
prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=10
)
)

########################################################################################
# The output might look like this (omitting some columns):
Expand Down Expand Up @@ -167,14 +178,17 @@
# Users can switch between CPU, CUDA, and XPU
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
device = 'cuda'
device = "cuda"
activities += [ProfilerActivity.CUDA]
elif torch.xpu.is_available():
device = 'xpu'
device = "xpu"
activities += [ProfilerActivity.XPU]
else:
print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
print(
"Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices"
)
import sys

sys.exit(0)

sort_by_keyword = device + "_time_total"
Expand Down Expand Up @@ -256,8 +270,9 @@
model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU],
profile_memory=True, record_shapes=True) as prof:
with profile(
activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True
) as prof:
model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
Expand Down Expand Up @@ -312,14 +327,17 @@
# Users can switch between CPU, CUDA, and XPU
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
device = 'cuda'
device = "cuda"
activities += [ProfilerActivity.CUDA]
elif torch.xpu.is_available():
device = 'xpu'
device = "xpu"
activities += [ProfilerActivity.XPU]
else:
print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
print(
"Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices"
)
import sys

sys.exit(0)

model = models.resnet18().to(device)
Expand Down Expand Up @@ -347,6 +365,7 @@
with profile(
activities=activities,
with_stack=True,
experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
) as prof:
model(inputs)

Expand Down Expand Up @@ -401,12 +420,7 @@

from torch.profiler import schedule

my_schedule = schedule(
skip_first=10,
wait=5,
warmup=1,
active=3,
repeat=2)
my_schedule = schedule(skip_first=10, wait=5, warmup=1, active=3, repeat=2)

######################################################################
# Profiler assumes that the long-running job is composed of steps, numbered
Expand Down Expand Up @@ -444,18 +458,17 @@

sort_by_keyword = "self_" + device + "_time_total"


def trace_handler(p):
    """Report and export the trace held by profiler ``p``.

    Prints a 10-row table of ``p.key_averages()`` sorted by the
    module-level ``sort_by_keyword``, then writes a Chrome trace to
    ``/tmp/trace_<step_num>.json``, where ``<step_num>`` is ``p.step_num``.
    """
    averages = p.key_averages()
    summary_table = averages.table(sort_by=sort_by_keyword, row_limit=10)
    print(summary_table)

    # One file per invocation, distinguished by the profiler's step number.
    trace_path = "/tmp/trace_" + str(p.step_num) + ".json"
    p.export_chrome_trace(trace_path)


with profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=2),
on_trace_ready=trace_handler
schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
on_trace_ready=trace_handler,
) as p:
for idx in range(8):
model(inputs)
Expand Down