
Commit 8de775a

saiteja64 and svekars authored
Update distributed checkpoint recipes (#3446)
* update distributed checkpoint recipes
* update text to be more clear in distributed_checkpoint_recipe.rst

---------

Co-authored-by: saiteja64 <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent ab48a0c commit 8de775a

File tree

2 files changed: +18 -21 lines


recipes_source/distributed_async_checkpoint_recipe.rst

Lines changed: 13 additions & 15 deletions
@@ -51,10 +51,9 @@ Specifically:
     import torch.multiprocessing as mp
     import torch.nn as nn

-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


     CHECKPOINT_DIR = "checkpoint"
@@ -74,7 +73,7 @@ Specifically:

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -105,7 +104,7 @@ Specifically:
        os.environ["MASTER_PORT"] = "12355 "

        # initialize the process group
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

@@ -119,7 +118,7 @@ Specifically:

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
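Taken together, the hunks above move the recipe's first example from the FullyShardedDataParallel wrapper to the FSDP2-style fully_shard API, make get_state_dict read self.model and self.optimizer inside the Stateful wrapper, and switch the process group backend to gloo. A minimal sketch of those pieces, reconstructed from the diff context above (the ToyModel layer sizes and the MASTER_ADDR value are assumptions for illustration, not the recipe's exact code):

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.fsdp import fully_shard

    CHECKPOINT_DIR = "checkpoint"


    class AppState(Stateful):
        """Bundles model and optimizer so DCP can treat them as one checkpointable unit."""

        def __init__(self, model, optimizer):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # get_state_dict handles FSDP FQNs and returns sharded state dicts;
            # reading self.model / self.optimizer is the fix this commit makes.
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {"model": model_state_dict, "optim": optimizer_state_dict}

        def load_state_dict(self, state_dict):
            # pushes the loaded shards back into the live model and optimizer
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"],
            )


    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(16, 16)  # assumed sizes, for illustration only
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(torch.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"  # assumed rendezvous settings
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def build_sharded_model(rank):
        model = ToyModel().to(rank)
        # fully_shard shards the module in place and returns it, replacing the
        # old FullyShardedDataParallel(model) wrapper.
        model = fully_shard(model)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        return model, optimizer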
@@ -158,9 +157,9 @@ Specifically, this optimization attacks the main overhead of asynchronous checkp
 checkpoint requests users can take advantage of direct memory access to speed up this copy.

 .. note::
-   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
-   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
-   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
+   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
+   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
+   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
    leading to the same
 peak memory pressure being sustained through the application life.

@@ -175,11 +174,10 @@ checkpoint requests users can take advantage of direct memory access to speed up
     import torch.multiprocessing as mp
     import torch.nn as nn

-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-    from torch.distributed.checkpoint import StorageWriter
+    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter


     CHECKPOINT_DIR = "checkpoint"
@@ -199,7 +197,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -230,7 +228,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
        os.environ["MASTER_PORT"] = "12355 "

        # initialize the process group
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

@@ -244,7 +242,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -254,7 +252,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
        # into a persistent buffer with pinned memory enabled.
        # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
        # pinned memory buffer.
-        writer = StorageWriter(cached_state_dict=True)
+        writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
        checkpoint_future = None
        for step in range(10):
            optimizer.zero_grad()
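For the pinned-memory variant, the corrected StorageWriter call above is what makes the staging buffer persistent across checkpoint requests. A sketch of how a training loop can use it, assuming the AppState wrapper from the earlier sketch is passed in (tensor shapes and the step count are illustrative):

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

    CHECKPOINT_DIR = "checkpoint"


    def train_with_async_checkpoints(model, optimizer, app_state, rank, num_steps=10):
        loss_fn = torch.nn.MSELoss()
        # The writer must outlive individual requests: it owns the pinned-memory
        # staging buffer, which is why it is created once, outside the loop.
        writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
        state_dict = {"app": app_state}
        checkpoint_future = None

        for step in range(num_steps):
            optimizer.zero_grad()
            out = model(torch.rand(8, 16, device=rank))            # assumed input shape
            loss_fn(out, torch.rand(8, 8, device=rank)).backward()
            optimizer.step()

            # Resolve the previous async save before issuing a new one, so two
            # requests never contend for the same staging buffer.
            if checkpoint_future is not None:
                checkpoint_future.result()
            checkpoint_future = dcp.async_save(state_dict, storage_writer=writer)

        # make sure the final checkpoint has landed before returning
        if checkpoint_future is not None:
            checkpoint_future.result()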

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 5 additions & 6 deletions
@@ -59,10 +59,9 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
     import torch.multiprocessing as mp
     import torch.nn as nn

-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


     CHECKPOINT_DIR = "checkpoint"
@@ -127,7 +126,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
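The ``join=True,`` context in the next hunk is the tail of the recipe's multiprocessing launcher, which spawns one process per device. For reference, such a launcher looks roughly like the sketch below; the per-rank runner name is assumed here and stands in for whatever function the recipe defines:

    import torch
    import torch.multiprocessing as mp


    def main():
        # one process per visible device; each process saves its own shard
        world_size = torch.cuda.device_count()
        mp.spawn(
            run_fsdp_checkpoint_save_example,  # the recipe's per-rank runner (name assumed)
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )


    if __name__ == "__main__":
        main()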
@@ -152,7 +151,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
            join=True,
        )

-Please go ahead and check the `checkpoint` directory. You should see 8 checkpoint files as shown below.
+Please go ahead and check the `checkpoint` directory. You should see checkpoint files corresponding to the number of devices as shown below. For example, if you have 8 devices, you should see 8 files.

 .. figure:: /_static/img/distributed/distributed_checkpoint_generated_files.png
    :width: 100%
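The revised sentence above captures the key behavior: DCP shards the checkpoint, so each participating device contributes its own file. A small sketch of saving and then inspecting the output directory, assuming an already-populated state_dict (for example {"app": AppState(model, optimizer)} from the earlier sketch):

    import os

    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    CHECKPOINT_DIR = "checkpoint"


    def save_and_inspect(state_dict):
        # Collective call: every rank participates and writes its own shard.
        dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
        dist.barrier()
        if dist.get_rank() == 0:
            # Expect roughly one data file per rank, plus checkpoint metadata.
            print(sorted(os.listdir(CHECKPOINT_DIR)))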
@@ -183,7 +182,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch.multiprocessing as mp
     import torch.nn as nn

-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard


     CHECKPOINT_DIR = "checkpoint"
@@ -248,7 +247,7 @@ The reason that we need the ``state_dict`` prior to loading is:

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

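On the load side touched by the hunks above, the state_dict has to exist on every rank before dcp.load so that DCP can load directly into the sharded tensors. A sketch of one way that fits together, reusing the ToyModel and AppState classes from the earlier sketch (the recipe's exact load code may differ):

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.fsdp import fully_shard

    CHECKPOINT_DIR = "checkpoint"


    def run_checkpoint_load(rank, world_size):
        # Re-create and re-shard the model exactly as it was sharded for saving.
        model = ToyModel().to(rank)          # ToyModel/AppState from the earlier sketch
        model = fully_shard(model)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        state_dict = {"app": AppState(model, optimizer)}
        # For Stateful entries, dcp.load first calls state_dict() (this is why the
        # state_dict is needed prior to loading), loads into it in place, and then
        # calls load_state_dict() to push the values back into model and optimizer.
        dcp.load(state_dict, checkpoint_id=CHECKPOINT_DIR)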
