
Update distributed checkpoint recipes #3446

Merged · 6 commits · Jul 10, 2025

28 changes: 13 additions & 15 deletions recipes_source/distributed_async_checkpoint_recipe.rst
@@ -51,10 +51,9 @@ Specifically:
import torch.multiprocessing as mp
import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"

@@ -74,7 +73,7 @@ Specifically:

def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
@@ -105,7 +104,7 @@ Specifically:
os.environ["MASTER_PORT"] = "12355 "

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
dist.init_process_group("gloo", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


@@ -119,7 +118,7 @@ Specifically:

# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
-model = FSDP(model)
+model = fully_shard(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -158,9 +157,9 @@ Specifically, this optimization attacks the main overhead of asynchronous checkpointing
checkpoint requests users can take advantage of direct memory access to speed up this copy.

.. note::
The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
leading to the same
peak memory pressure being sustained through the application life.
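
The trade-off this note describes comes from keeping a page-locked CPU staging buffer alive across checkpoint requests. A minimal sketch of that idea (the tensor shape and variable names are illustrative, not taken from the recipe):

.. code-block:: python

    import torch

    # A GPU tensor standing in for checkpointable state.
    gpu_tensor = torch.randn(1024, 1024, device="cuda")

    # Allocate the pinned (page-locked) staging buffer once and keep it alive
    # between checkpoint requests; this is the persistent memory the note
    # warns about.
    staging_buffer = torch.empty(
        gpu_tensor.shape, dtype=gpu_tensor.dtype, device="cpu", pin_memory=True
    )

    # Device-to-host copies into pinned memory can use direct memory access
    # and overlap with other work when non_blocking=True.
    staging_buffer.copy_(gpu_tensor, non_blocking=True)
    torch.cuda.synchronize()  # ensure the copy has finished before persisting it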

@@ -175,11 +174,10 @@ checkpoint requests users can take advantage of direct memory access to speed up
import torch.multiprocessing as mp
import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.distributed.checkpoint import StorageWriter
+from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

CHECKPOINT_DIR = "checkpoint"

@@ -199,7 +197,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
@@ -230,7 +228,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
os.environ["MASTER_PORT"] = "12355 "

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
dist.init_process_group("gloo", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


@@ -244,7 +242,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
-model = FSDP(model)
+model = fully_shard(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -254,7 +252,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
# into a persistent buffer with pinned memory enabled.
# Note: It's important that the writer persists in between checkpointing requests, since it maintains the
# pinned memory buffer.
-writer = StorageWriter(cached_state_dict=True)
+writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
checkpoint_future = None
for step in range(10):
optimizer.zero_grad()
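
Taken together, these changes switch the staging writer to ``FileSystemWriter`` (imported as ``StorageWriter``) with ``cache_staged_state_dict=True`` and keep it alive across the training loop. A rough sketch of how the pieces fit, assuming the recipe's ``AppState``, ``model``, ``optimizer``, and ``CHECKPOINT_DIR`` are already defined as in the hunks above (the loop body is abridged):

.. code-block:: python

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

    # One writer for the whole run: it owns the pinned staging buffer,
    # so it must persist between checkpoint requests.
    writer = StorageWriter(path=CHECKPOINT_DIR, cache_staged_state_dict=True)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # The staging buffer is shared between requests, so wait for the
        # previous asynchronous save before starting a new one.
        if checkpoint_future is not None:
            checkpoint_future.result()

        checkpoint_future = dcp.async_save(
            {"app": AppState(model, optimizer)}, storage_writer=writer
        )
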
11 changes: 5 additions & 6 deletions recipes_source/distributed_checkpoint_recipe.rst
@@ -59,10 +59,9 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
import torch.multiprocessing as mp
import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"

@@ -127,7 +126,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input

# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
-model = FSDP(model)
+model = fully_shard(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -152,7 +151,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
join=True,
)

-Please go ahead and check the `checkpoint` directory. You should see 8 checkpoint files as shown below.
+Please go ahead and check the `checkpoint` directory. You should see checkpoint files corresponding to the number of devices, as shown below. For example, if you have 8 devices, you should see 8 files.

.. figure:: /_static/img/distributed/distributed_checkpoint_generated_files.png
:width: 100%
@@ -183,7 +182,7 @@ The reason that we need the ``state_dict`` prior to loading is:
import torch.multiprocessing as mp
import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"

@@ -248,7 +247,7 @@ The reason that we need the ``state_dict`` prior to loading is:

# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
-model = FSDP(model)
+model = fully_shard(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

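For context on the load-side hunks, the pattern the recipe relies on is: materialize the (sharded) state dict first, let DCP restore into it in place, then push the values back into the live objects. A compact sketch, assuming ``model`` and ``optimizer`` are already wrapped with ``fully_shard`` as in the diff, ``CHECKPOINT_DIR`` is the directory written earlier, and the top-level keys match the ones used at save time:

.. code-block:: python

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    # dcp.load restores data in place, so the tensors (and their sharding
    # layout) must already exist on each rank before loading.
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)

    dcp.load(
        {"model": model_state_dict, "optim": optimizer_state_dict},
        checkpoint_id=CHECKPOINT_DIR,
    )

    # Copy the loaded values back into the live model and optimizer.
    set_state_dict(
        model,
        optimizer,
        model_state_dict=model_state_dict,
        optim_state_dict=optimizer_state_dict,
    )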