diff --git a/distributed/ddp/README.md b/distributed/ddp/README.md
index 7b34a25354..4f110affe6 100644
--- a/distributed/ddp/README.md
+++ b/distributed/ddp/README.md
@@ -1,167 +1,129 @@
-# Launching and configuring distributed data parallel applications
-In this tutorial we will demonstrate how to structure a distributed
-model training application so it can be launched conveniently on
-multiple nodes, each with multiple GPUs using PyTorch's distributed
-[launcher script](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).
+# Distributed Data Parallel (DDP) Applications with PyTorch
 
-# Prerequisites
+This guide demonstrates how to structure a distributed model training application so it can be launched conveniently on multiple nodes, each with multiple GPUs, using `torchrun`.
 
-We assume you are familiar with [PyTorch](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html), the primitives it provides for [writing distributed applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html) as well as training [distributed models](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
+---
 
-The example program in this tutorial uses the
-[`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) class for training models
-in a _data parallel_ fashion: multiple workers train the same global
-model by processing different portions of a large dataset, computing
-local gradients (aka _sub_-gradients) independently and then
-collectively synchronizing gradients using the AllReduce primitive. In
-HPC terminology, this model of execution is called _Single Program
-Multiple Data_ or SPMD since the same application runs on all
-application but each one operates on different portions of the
-training dataset.
+## Prerequisites
 
-# Application process topologies
+You should be familiar with:
+
+- [PyTorch basics](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
+- [Writing distributed applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html)
+- [Distributed model training](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
+
+This tutorial uses the [`torch.nn.parallel.DistributedDataParallel`](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) (DDP) class for data parallel training: multiple workers train the same global model on different data shards, compute local gradients, and synchronize them using AllReduce. In High-Performance Computing (HPC), this is called _Single Program Multiple Data_ (SPMD).
+
+---
+
+## Application Process Topologies
 
 A Distributed Data Parallel (DDP) application can be executed on
 multiple nodes where each node can consist of multiple GPU
 devices. Each node in turn can run multiple copies of the DDP
 application, each of which processes its models on multiple GPUs.
 
-Let _N_ be the number of nodes on which the application is running and
-_G_ be the number of GPUs per node. The total number of application
-processes running across all the nodes at one time is called the
-**World Size**, _W_ and the number of processes running on each node
-is referred to as the **Local World Size**, _L_.
+Let:
+- _N_ = number of nodes
+- _G_ = number of GPUs per node
+- _W_ = **World Size** = total number of processes
+- _L_ = **Local World Size** = processes per node
 
-Each application process is assigned two IDs: a _local_ rank in \[0,
-_L_-1\] and a _global_ rank in \[0, _W_-1\].
+Each process has:
+- **Local rank**: in `[0, L-1]`
+- **Global rank**: in `[0, W-1]`
 
-To illustrate the terminology defined above, consider the case where a
-DDP application is launched on two nodes, each of which has four
-GPUs. We would then like each process to span two GPUs each. The
-mapping of processes to nodes is shown in the figure below:
+**Example:**
+If you launch a DDP app on 2 nodes, each with 4 GPUs, and want each process to span 2 GPUs, the mapping is as follows:
 
 ![ProcessMapping](https://user-images.githubusercontent.com/875518/77676984-4c81e400-6f4c-11ea-87d8-f2ff505a99da.png)
 
-While there are quite a few ways to map processes to nodes, a good
-rule of thumb is to have one process span a single GPU. This enables
-the DDP application to have as many parallel reader streams as there
-are GPUs and in practice provides a good balance between I/O and
-computational costs. In the rest of this tutorial, we assume that the
-application follows this heuristic.
+While there are quite a few ways to map processes to nodes, a good rule of thumb is to have one process span a single GPU. This enables the DDP application to have as many parallel reader streams as there are GPUs and in practice provides a good balance between I/O and computational costs. In the rest of this tutorial, we assume that the application follows this heuristic.
 
 # Preparing and launching a DDP application
 
-Independent of how a DDP application is launched, each process needs a
-mechanism to know its global and local ranks. Once this is known, all
-processes create a `ProcessGroup` that enables them to participate in
-collective communication operations such as AllReduce.
-
-A convenient way to start multiple DDP processes and initialize all
-values needed to create a `ProcessGroup` is to use the distributed
-`launch.py` script provided with PyTorch. The launcher can be found
-under the `distributed` subdirectory under the local `torch`
-installation directory. Here is a quick way to get the path of
-`launch.py` on any operating system:
-
-```sh
-python -c "from os import path; import torch; print(path.join(path.dirname(torch.__file__), 'distributed', 'launch.py'))"
-```
-
-This will print something like this:
-
-```sh
-/home/username/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/distributed/launch.py
-```
-
-When the DDP application is started via `launch.py`, it passes the world size, global rank, master address and master port via environment variables and the local rank as a command-line parameter to each instance.
-To use the launcher, an application needs to adhere to the following convention:
-
-1. It must provide an entry-point function for a _single worker_. For example, it should not launch subprocesses using `torch.multiprocessing.spawn`
-2. It must use environment variables for initializing the process group.
-
-For simplicity, the application can assume each process maps to a single GPU but in the next section we also show how a more general process-to-GPU mapping can be performed.
-
-# Sample application
+Independent of how a DDP application is launched, each process needs a mechanism to know its global and local ranks. Once this is known, all processes create a `ProcessGroup` that enables them to participate in collective communication operations such as AllReduce.
 
-The sample DDP application in this repo is based on the "Hello, World" [DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
+A convenient way to start multiple DDP processes and initialize all values needed to create a `ProcessGroup` is to use the [`torchrun`](https://docs.pytorch.org/docs/stable/elastic/run.html) script provided with PyTorch.
 
-## Argument passing convention
+---
 
-The DDP application takes two command-line arguments:
+## Sample Application
 
-1. `--local_rank`: This is passed in via `launch.py`
-2. `--local_world_size`: This is passed in explicitly and is typically either $1$ or the number of GPUs per node.
+This example is based on the ["Hello, World" DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
 
-The application parses these and calls the `spmd_main` entrypoint:
+The application calls the `main` entrypoint:
 
-```py
+```python
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--local_rank", type=int, default=0)
-    parser.add_argument("--local_world_size", type=int, default=1)
-    args = parser.parse_args()
-    spmd_main(args.local_world_size, args.local_rank)
+    main()
 ```
 
-In `spmd_main`, the process group is initialized with just the backend (NCCL or Gloo). The rest of the information needed for rendezvous comes from environment variables set by `launch.py`:
+In `main`, the process group is initialized using the Accelerator API. The rest of the rendezvous information comes from environment variables set by `torchrun`:
 
-```py
-def spmd_main(local_world_size, local_rank):
+```python
+def main():
     # These are the parameters used to initialize the process group
     env_dict = {
         key: os.environ[key]
-        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
+        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
     }
+    rank = int(env_dict['RANK'])
+    local_rank = int(env_dict['LOCAL_RANK'])
+    local_world_size = int(env_dict['LOCAL_WORLD_SIZE'])
+
     print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
-    dist.init_process_group(backend="nccl")
-    print(
-        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
-        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
-    )
+    acc = torch.accelerator.current_accelerator()
+    vendor_backend = torch.distributed.get_default_backend_for_device(acc)
+    torch.accelerator.set_device_index(rank)
+    torch.distributed.init_process_group(backend=vendor_backend)
 
-    demo_basic(local_world_size, local_rank)
+    demo_basic(rank)
 
     # Tear down the process group
-    dist.destroy_process_group()
+    torch.distributed.destroy_process_group()
 ```
 
-Given the local rank and world size, the training function, `demo_basic` initializes the `DistributedDataParallel` model across a set of GPUs local to the node via `device_ids`:
-
-```py
-def demo_basic(local_world_size, local_rank):
+**Key points:**
+- Each process reads its rank and world size from environment variables.
+- The process group is initialized for distributed communication.
 
-    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
-    # rank 0 uses GPUs [0, 1, 2, 3] and
-    # rank 1 uses GPUs [4, 5, 6, 7].
-    n = torch.cuda.device_count() // local_world_size
-    device_ids = list(range(local_rank * n, (local_rank + 1) * n))
+The training function, `demo_basic`, initializes the DDP model on the appropriate GPU:
 
+```python
+def demo_basic(rank):
     print(
-        f"[{os.getpid()}] rank = {dist.get_rank()}, "
-        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
+        f"[{os.getpid()}] rank = {torch.distributed.get_rank()}, "
+        + f"world_size = {torch.distributed.get_world_size()}"
     )
-    model = ToyModel().cuda(device_ids[0])
-    ddp_model = DDP(model, device_ids)
+    model = ToyModel().to(rank)
+    ddp_model = DDP(model, device_ids=[rank])
     loss_fn = nn.MSELoss()
     optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
     optimizer.zero_grad()
     outputs = ddp_model(torch.randn(20, 10))
-    labels = torch.randn(20, 5).to(device_ids[0])
+    labels = torch.randn(20, 5).to(rank)
     loss_fn(outputs, labels).backward()
     optimizer.step()
 ```
 
-The application can be launched via `launch.py` as follows on a 8 GPU node with one process per GPU:
+---
+
+## Launching the Application
 
 ```sh
-python /path/to/launch.py --nnode=1 --node_rank=0 --nproc_per_node=8 example.py --local_world_size=8
+torchrun --nnodes=1 --nproc_per_node=8 example.py
 ```
 
-and produces an output similar to the one shown below:
+---
+
+## Example Output
+
+Running on a single node with 8 GPUs produces output similar to the following:
 
 ```sh
 *****************************************
@@ -183,30 +145,16 @@ Setting OMP_NUM_THREADS environment variable for each process to be 1 in default
 [238632] world_size = 8, rank = 5, backend=nccl
 [238634] world_size = 8, rank = 7, backend=nccl
 [238627] world_size = 8, rank = 0, backend=nccl
-[238633] rank = 6, world_size = 8, n = 1, device_ids = [6]
-[238628] rank = 1, world_size = 8, n = 1, device_ids = [1]
-[238632] rank = 5, world_size = 8, n = 1, device_ids = [5]
-[238634] rank = 7, world_size = 8, n = 1, device_ids = [7]
-[238629] rank = 2, world_size = 8, n = 1, device_ids = [2]
-[238630] rank = 3, world_size = 8, n = 1, device_ids = [3]
-[238631] rank = 4, world_size = 8, n = 1, device_ids = [4]
-[238627] rank = 0, world_size = 8, n = 1, device_ids = [0]
-```
-
-Similarly, it can be launched with a single process that spans all 8 GPUs using:
-
-```sh
-python /path/to/launch.py --nnode=1 --node_rank=0 --nproc_per_node=1 example.py --local_world_size=1
-```
-
-that in turn produces the following output
-
-```sh
-[262816] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '0', 'WORLD_SIZE': '1'}
-[262816]: world_size = 1, rank = 0, backend=nccl
-[262816] rank = 0, world_size = 1, n = 8, device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
+[238633] rank = 6, world_size = 8
+[238628] rank = 1, world_size = 8
+[238632] rank = 5, world_size = 8
+[238634] rank = 7, world_size = 8
+[238629] rank = 2, world_size = 8
+[238630] rank = 3, world_size = 8
+[238631] rank = 4, world_size = 8
+[238627] rank = 0, world_size = 8
 ```
 
 # Conclusions
 
-As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using the launcher, the mechanics of setting up distributed training can be significantly simplified.
+As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using `torchrun`, the mechanics of setting up distributed training can be significantly simplified.
diff --git a/distributed/ddp/example.py b/distributed/ddp/example.py
index 4110fa2268..d64b473976 100644
--- a/distributed/ddp/example.py
+++ b/distributed/ddp/example.py
@@ -11,6 +11,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """ verification that we have at least 2 gpus to run dist examples """
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
+
 class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
@@ -22,38 +28,37 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))
 
 
-def demo_basic(local_world_size, local_rank):
-
-    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
-    # rank 0 uses GPUs [0, 1, 2, 3] and
-    # rank 1 uses GPUs [4, 5, 6, 7].
-    n = torch.cuda.device_count() // local_world_size
-    device_ids = list(range(local_rank * n, (local_rank + 1) * n))
+def demo_basic(rank):
     print(
         f"[{os.getpid()}] rank = {dist.get_rank()}, "
-        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end=''
-    )
+        + f"world_size = {dist.get_world_size()}"
+    )
 
-    model = ToyModel().cuda(device_ids[0])
-    ddp_model = DDP(model, device_ids)
+    model = ToyModel().to(rank)
+    ddp_model = DDP(model, device_ids=[rank])
 
     loss_fn = nn.MSELoss()
     optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
 
     optimizer.zero_grad()
     outputs = ddp_model(torch.randn(20, 10))
-    labels = torch.randn(20, 5).to(device_ids[0])
+    labels = torch.randn(20, 5).to(rank)
     loss_fn(outputs, labels).backward()
     optimizer.step()
+    print(f"training completed in rank {rank}!")
 
 
-def spmd_main(local_world_size, local_rank):
+
+def main():
     # These are the parameters used to initialize the process group
     env_dict = {
         key: os.environ[key]
-        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
+        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
     }
+    rank = int(env_dict['RANK'])
+    local_rank = int(env_dict['LOCAL_RANK'])
+    local_world_size = int(env_dict['LOCAL_WORLD_SIZE'])
 
     if sys.platform == "win32":
         # Distributed package only covers collective communications with Gloo
@@ -73,25 +78,24 @@ def spmd_main(local_world_size, local_rank):
         dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]), world_size=int(env_dict["WORLD_SIZE"]))
     else:
         print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
-        dist.init_process_group(backend="nccl")
+        acc = torch.accelerator.current_accelerator()
+        backend = torch.distributed.get_default_backend_for_device(acc)
+        torch.accelerator.set_device_index(rank)
+        dist.init_process_group(backend=backend)
 
     print(
         f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()} \n", end=''
     )
 
-    demo_basic(local_world_size, local_rank)
+    demo_basic(rank)
 
     # Tear down the process group
     dist.destroy_process_group()
 
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    # This is passed in via launch.py
-    parser.add_argument("--local_rank", type=int, default=0)
-    # This needs to be explicitly passed in
-    parser.add_argument("--local_world_size", type=int, default=1)
-    args = parser.parse_args()
-    # The main entry point is called directly without using subprocess
-    spmd_main(args.local_world_size, args.local_rank)
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        sys.exit()
+    main()
diff --git a/distributed/ddp/main.py b/distributed/ddp/main.py
deleted file mode 100644
index 34a855f051..0000000000
--- a/distributed/ddp/main.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import os
-import tempfile
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.optim as optim
-
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-
-def setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
-
-    # initialize the process group
-    dist.init_process_group("gloo", rank=rank, world_size=world_size)
-
-
-def cleanup():
-    dist.destroy_process_group()
-
-
-class ToyModel(nn.Module):
-    def __init__(self):
-        super(ToyModel, self).__init__()
-        self.net1 = nn.Linear(10, 10)
-        self.relu = nn.ReLU()
-        self.net2 = nn.Linear(10, 5)
-
-    def forward(self, x):
-        return self.net2(self.relu(self.net1(x)))
-
-
-def demo_basic(rank, world_size):
-    print(f"Running basic DDP example on rank {rank}.")
-    setup(rank, world_size)
-
-    # create model and move it to GPU with id rank
-    model = ToyModel().to(rank)
-    ddp_model = DDP(model, device_ids=[rank])
-
-    loss_fn = nn.MSELoss()
-    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
-
-    optimizer.zero_grad()
-    outputs = ddp_model(torch.randn(20, 10))
-    labels = torch.randn(20, 5).to(rank)
-    loss_fn(outputs, labels).backward()
-    optimizer.step()
-
-    cleanup()
-
-
-def run_demo(demo_fn, world_size):
-    mp.spawn(demo_fn,
-             args=(world_size,),
-             nprocs=world_size,
-             join=True)
-
-
-def demo_checkpoint(rank, world_size):
-    print(f"Running DDP checkpoint example on rank {rank}.")
-    setup(rank, world_size)
-
-    model = ToyModel().to(rank)
-    ddp_model = DDP(model, device_ids=[rank])
-
-    loss_fn = nn.MSELoss()
-    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
-
-    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
-    if rank == 0:
-        # All processes should see same parameters as they all start from same
-        # random parameters and gradients are synchronized in backward passes.
-        # Therefore, saving it in one process is sufficient.
-        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
-
-    # Use a barrier() to make sure that process 1 loads the model after process
-    # 0 saves it.
-    dist.barrier()
-    # configure map_location properly
-    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
-    ddp_model.load_state_dict(
-        torch.load(CHECKPOINT_PATH, map_location=map_location))
-
-    optimizer.zero_grad()
-    outputs = ddp_model(torch.randn(20, 10))
-    labels = torch.randn(20, 5).to(rank)
-    loss_fn = nn.MSELoss()
-    loss_fn(outputs, labels).backward()
-    optimizer.step()
-
-    # Use a barrier() to make sure that all processes have finished reading the
-    # checkpoint
-    dist.barrier()
-
-    if rank == 0:
-        os.remove(CHECKPOINT_PATH)
-
-    cleanup()
-
-
-class ToyMpModel(nn.Module):
-    def __init__(self, dev0, dev1):
-        super(ToyMpModel, self).__init__()
-        self.dev0 = dev0
-        self.dev1 = dev1
-        self.net1 = torch.nn.Linear(10, 10).to(dev0)
-        self.relu = torch.nn.ReLU()
-        self.net2 = torch.nn.Linear(10, 5).to(dev1)
-
-    def forward(self, x):
-        x = x.to(self.dev0)
-        x = self.relu(self.net1(x))
-        x = x.to(self.dev1)
-        return self.net2(x)
-
-
-def demo_model_parallel(rank, world_size):
-    print(f"Running DDP with model parallel example on rank {rank}.")
-    setup(rank, world_size)
-
-    # setup mp_model and devices for this process
-    dev0 = rank * 2
-    dev1 = rank * 2 + 1
-    mp_model = ToyMpModel(dev0, dev1)
-    ddp_mp_model = DDP(mp_model)
-
-    loss_fn = nn.MSELoss()
-    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
-
-    optimizer.zero_grad()
-    # outputs will be on dev1
-    outputs = ddp_mp_model(torch.randn(20, 10))
-    labels = torch.randn(20, 5).to(dev1)
-    loss_fn(outputs, labels).backward()
-    optimizer.step()
-
-    cleanup()
-
-
-if __name__ == "__main__":
-    n_gpus = torch.cuda.device_count()
-    if n_gpus < 8:
-        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
-    else:
-        run_demo(demo_basic, 8)
-        run_demo(demo_checkpoint, 8)
-        run_demo(demo_model_parallel, 4)
diff --git a/distributed/ddp/requirements.txt b/distributed/ddp/requirements.txt
index 12c6d5d5ea..285a4d8195 100644
--- a/distributed/ddp/requirements.txt
+++ b/distributed/ddp/requirements.txt
@@ -1 +1 @@
-torch
+torch>=2.7
diff --git a/distributed/ddp/run_example.sh b/distributed/ddp/run_example.sh
new file mode 100755
index 0000000000..d439b681b4
--- /dev/null
+++ b/distributed/ddp/run_example.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
diff --git a/run_distributed_examples.sh b/run_distributed_examples.sh
index e1f579c072..5bf6f10894 100755
--- a/run_distributed_examples.sh
+++ b/run_distributed_examples.sh
@@ -51,7 +51,7 @@ function distributed_tensor_parallelism() {
 }
 
 function distributed_ddp() {
-    uv run main.py || error "ddp example failed"
+    uv run bash run_example.sh example.py || error "ddp example failed"
 }
 
 function run_all() {
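
Supplementary sketch (not part of the patch above): the README describes data parallel training as each worker processing a different shard of a large dataset, while `example.py` trains on synthetic `torch.randn` batches. One plausible way to add real sharding is `torch.utils.data.DistributedSampler`; the stand-in dataset, sizes, and function name below are illustrative assumptions, not code from this PR.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def make_dataloader(batch_size: int = 32) -> DataLoader:
    # Stand-in dataset: 1024 random (input, label) pairs shaped like
    # ToyModel's 10-feature inputs and 5-feature outputs.
    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 5))

    # DistributedSampler queries the already-initialized default process group
    # for the world size and rank, and assigns each rank a disjoint subset of
    # indices, so every process trains on a different shard.
    sampler = DistributedSampler(dataset, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

Inside `demo_basic`, the single forward/backward step could then become a loop over `make_dataloader()`, with `loader.sampler.set_epoch(epoch)` called at the start of each epoch so that shuffling differs between epochs.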
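Another illustrative sketch, likewise not part of the patch: after `init_process_group` returns, the rank bookkeeping that `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`) can be checked with a small AllReduce. The function name and the self-check are assumptions for demonstration, and it presumes one GPU per process, as in the example.

```python
import os

import torch
import torch.distributed as dist


def allreduce_sanity_check() -> None:
    # torchrun sets these for every worker; on a single node,
    # RANK == LOCAL_RANK and WORLD_SIZE == LOCAL_WORLD_SIZE.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Each rank contributes its own index; after the AllReduce every rank
    # holds the same value 0 + 1 + ... + (world_size - 1).
    contribution = torch.tensor([float(rank)], device=local_rank)
    dist.all_reduce(contribution, op=dist.ReduceOp.SUM)

    expected = world_size * (world_size - 1) / 2
    assert contribution.item() == expected, f"unexpected sum on rank {rank}"
    print(f"rank {rank}: AllReduce OK, sum = {contribution.item()}")
```

Calling this at the top of `demo_basic` is a quick way to confirm that all processes joined the same process group before any training starts.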