diff --git a/recipes_source/distributed_async_checkpoint_recipe.rst b/recipes_source/distributed_async_checkpoint_recipe.rst
index a7194f6c58..0e9add5148 100644
--- a/recipes_source/distributed_async_checkpoint_recipe.rst
+++ b/recipes_source/distributed_async_checkpoint_recipe.rst
@@ -51,10 +51,9 @@ Specifically:
     import torch.multiprocessing as mp
     import torch.nn as nn
 
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
     CHECKPOINT_DIR = "checkpoint"
 
@@ -74,7 +73,7 @@ Specifically:
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -105,7 +104,7 @@ Specifically:
         os.environ["MASTER_PORT"] = "12355 "
 
         # initialize the process group
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
         torch.cuda.set_device(rank)
 
 
@@ -119,7 +118,7 @@ Specifically:
 
         # create a model and move it to GPU with id rank
         model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)
 
         loss_fn = nn.MSELoss()
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -158,9 +157,9 @@ Specifically, this optimization attacks the main overhead of asynchronous checkp
 checkpoint requests users can take advantage of direct memory access to speed up this copy.
 
 .. note::
-   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without 
-   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as 
-   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps, 
+   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
+   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
+   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
    leading to the same peak memory pressure being sustained through the application life.
 
 
@@ -175,11 +174,10 @@ checkpoint requests users can take advantage of direct memory access to speed up
     import torch.multiprocessing as mp
     import torch.nn as nn
 
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-    from torch.distributed.checkpoint import StorageWriter
+    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
 
     CHECKPOINT_DIR = "checkpoint"
 
@@ -199,7 +197,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
 
         def state_dict(self):
             # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
             return {
                 "model": model_state_dict,
                 "optim": optimizer_state_dict
@@ -230,7 +228,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
         os.environ["MASTER_PORT"] = "12355 "
 
         # initialize the process group
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
         torch.cuda.set_device(rank)
 
 
@@ -244,7 +242,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
 
         # create a model and move it to GPU with id rank
         model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)
 
         loss_fn = nn.MSELoss()
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -254,7 +252,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
         # into a persistent buffer with pinned memory enabled.
         # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
         # pinned memory buffer.
-        writer = StorageWriter(cached_state_dict=True)
+        writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
         checkpoint_future = None
         for step in range(10):
             optimizer.zero_grad()
diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index 8a81d63bb6..de31d43040 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -59,10 +59,9 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
     import torch.multiprocessing as mp
     import torch.nn as nn
 
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
     from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     from torch.distributed.checkpoint.stateful import Stateful
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
     CHECKPOINT_DIR = "checkpoint"
 
@@ -127,7 +126,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
 
         # create a model and move it to GPU with id rank
         model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)
 
         loss_fn = nn.MSELoss()
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -152,7 +151,7 @@ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input
             join=True,
         )
 
-Please go ahead and check the `checkpoint` directory. You should see 8 checkpoint files as shown below.
+Please go ahead and check the `checkpoint` directory. You should see checkpoint files corresponding to the number of devices, as shown below. For example, if you have 8 devices, you should see 8 files.
 
 .. figure:: /_static/img/distributed/distributed_checkpoint_generated_files.png
    :width: 100%
@@ -183,7 +182,7 @@ The reason that we need the ``state_dict`` prior to loading is:
     import torch.multiprocessing as mp
     import torch.nn as nn
 
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
 
     CHECKPOINT_DIR = "checkpoint"
 
@@ -248,7 +247,7 @@ The reason that we need the ``state_dict`` prior to loading is:
 
         # create a model and move it to GPU with id rank
         model = ToyModel().to(rank)
-        model = FSDP(model)
+        model = fully_shard(model)
 
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 
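For orientation, the hunks above converge on one pattern in both recipes: fully_shard for FSDP2 wrapping, get_state_dict for FQN handling, and a single persistent FileSystemWriter with cache_staged_state_dict=True driving dcp.async_save. The condensed sketch below is illustrative rather than part of the patch: the ToyModel sizes, the run helper, and the combined "cpu:gloo,cuda:nccl" backend (used here so both CUDA collectives and CPU-side checkpoint coordination are supported) are assumptions, not code taken from the recipes.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    # Illustrative layer sizes; any module works the same way.
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Combined backend is an assumption of this sketch; the recipes use a single backend.
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = fully_shard(ToyModel().to(rank))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    # Create the writer once and reuse it so its pinned staging buffer
    # persists across checkpoint requests.
    writer = FileSystemWriter(path=CHECKPOINT_DIR, cache_staged_state_dict=True)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        out = model(torch.rand(8, 16, device=rank))
        loss_fn(out, torch.rand(8, 8, device=rank)).backward()
        optimizer.step()

        # Wait for the previous asynchronous save before issuing a new one.
        if checkpoint_future is not None:
            checkpoint_future.result()

        # get_state_dict handles FSDP FQNs and returns sharded state dicts.
        model_sd, optim_sd = get_state_dict(model, optimizer)
        checkpoint_future = dcp.async_save(
            {"model": model_sd, "optim": optim_sd},
            storage_writer=writer,
        )

    if checkpoint_future is not None:
        checkpoint_future.result()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)

Reusing one writer is what keeps the pinned staging buffer alive between requests, which is the trade-off described in the note in the async recipe above.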