@@ -51,10 +51,9 @@ Specifically:
 import torch.multiprocessing as mp
 import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

 CHECKPOINT_DIR = "checkpoint"

@@ -74,7 +73,7 @@ Specifically:

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
@@ -105,7 +104,7 @@ Specifically:
     os.environ["MASTER_PORT"] = "12355"

     # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)


@@ -119,7 +118,7 @@ Specifically:

     # create a model and move it to GPU with id rank
     model = ToyModel().to(rank)
-    model = FSDP(model)
+    model = fully_shard(model)

     loss_fn = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -158,9 +157,9 @@ Specifically, this optimization attacks the main overhead of asynchronous checkp
 checkpoint requests users can take advantage of direct memory access to speed up this copy.

 .. note::
-   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
-   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
-   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
+   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
+   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
+   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
    leading to the same
    peak memory pressure being sustained through the application life.

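
As a side illustration of the copy the note refers to (this snippet is mine, not part of the patch): staging a checkpoint copies CUDA tensors into host memory, and keeping that host buffer page-locked ("pinned") lets the device-to-host copy run asynchronously over DMA instead of blocking on pageable memory. The tensor shapes below are arbitrary.

.. code-block:: python

    import torch

    # Illustrative only: a pinned (page-locked) CPU buffer allows an async,
    # DMA-backed device-to-host copy; a pageable buffer would block.
    device_tensor = torch.randn(1024, 1024, device="cuda")
    staging_buffer = torch.empty(
        device_tensor.shape, dtype=device_tensor.dtype, device="cpu", pin_memory=True
    )
    staging_buffer.copy_(device_tensor, non_blocking=True)  # overlaps with GPU work
    torch.cuda.synchronize()  # wait before reading the staged copy on the host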
@@ -175,11 +174,10 @@ checkpoint requests users can take advantage of direct memory access to speed up
 import torch.multiprocessing as mp
 import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.distributed.checkpoint import StorageWriter
+from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

 CHECKPOINT_DIR = "checkpoint"

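
A short aside on the import swap above (my reading of it, not stated in the patch): ``StorageWriter`` in ``torch.distributed.checkpoint`` is the abstract writer interface, while ``FileSystemWriter`` is its concrete filesystem implementation, so aliasing the concrete class to the old name keeps the rest of the example source unchanged. A quick check of that relationship:

.. code-block:: python

    from torch.distributed.checkpoint import FileSystemWriter, StorageWriter

    # FileSystemWriter implements the abstract StorageWriter interface, so code
    # that only needs "a StorageWriter" keeps working after the alias.
    assert issubclass(FileSystemWriter, StorageWriter)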
@@ -199,7 +197,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
@@ -230,7 +228,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
     os.environ["MASTER_PORT"] = "12355"

     # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)


@@ -244,7 +242,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

     # create a model and move it to GPU with id rank
     model = ToyModel().to(rank)
-    model = FSDP(model)
+    model = fully_shard(model)

     loss_fn = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -254,7 +252,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
     # into a persistent buffer with pinned memory enabled.
     # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
     # pinned memory buffer.
-    writer = StorageWriter(cached_state_dict=True)
+    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
     checkpoint_future = None
     for step in range(10):
         optimizer.zero_grad()
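
For completeness, here is a sketch of how a loop like the one above typically continues. The training step, the ``AppState`` wrapper from the earlier sketch, the tensor shapes, and the use of ``torch.distributed.checkpoint.async_save`` are all assumptions; the patch itself only shows the loop header.

.. code-block:: python

    import torch
    import torch.distributed.checkpoint as dcp

    # Sketch only: reuse the single `writer` so its pinned staging buffer
    # survives across checkpoint requests, and wait on the previous async
    # save before queuing the next one.
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(torch.rand(8, 16, device="cuda")),
                       torch.rand(8, 8, device="cuda"))
        loss.backward()
        optimizer.step()

        if checkpoint_future is not None:
            checkpoint_future.result()  # wait for the previous save to finish
        state_dict = {"app": AppState(model, optimizer)}
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer)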