Update ddp_minGPT to remove FSDP1 references #3442

Merged: 2 commits, Jul 9, 2025
23 changes: 11 additions & 12 deletions in intermediate_source/ddp_series_minGPT.rst
@@ -26,10 +26,11 @@ Authors: `Suraj Subramanian <https://github.com/subramen>`__
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

-- Familiarity with `multi-GPU training <../beginner/ddp_series_multigpu.html>`__ and `torchrun <../beginner/ddp_series_fault_tolerance.html>`__
-- [Optional] Familiarity with `multinode training <ddp_series_multinode.html>`__
-- 2 or more TCP-reachable GPU machines (this tutorial uses AWS p3.2xlarge instances)
- PyTorch `installed <https://pytorch.org/get-started/locally/>`__ with CUDA on all machines
+- Familiarity with `multi-GPU training <../beginner/ddp_series_multigpu.html>`__ and `torchrun <../beginner/ddp_series_fault_tolerance.html>`__
+- [Optional] Familiarity with `multinode training <ddp_series_multinode.html>`__
+- 2 or more TCP-reachable GPU machines for multi-node training (this tutorial uses AWS p3.2xlarge instances)


Follow along with the video below or on `youtube <https://www.youtube.com/watch/XFsFDGKZHh4>`__.

@@ -63,25 +64,23 @@ from any node that has access to the cloud bucket.

Using Mixed Precision
~~~~~~~~~~~~~~~~~~~~~~~~
To speed things up, you might be able to use `Mixed Precision <https://pytorch.org/docs/stable/amp.html>`__ to train your models.
In Mixed Precision, some parts of the training process are carried out in reduced precision, while other steps
that are more sensitive to precision drops are maintained in FP32 precision.
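
The Mixed Precision section above refers to PyTorch's torch.amp API. A minimal sketch of the autocast-plus-GradScaler pattern it describes, assuming a recent PyTorch and a CUDA GPU (the toy model, shapes, and training loop are placeholders rather than the minGPT trainer):

import torch
import torch.nn.functional as F

device = "cuda"  # assumes a CUDA GPU, as in the tutorial's setup
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.amp.GradScaler("cuda")  # rescales the loss so FP16 gradients do not underflow

for step in range(10):  # stand-in for the real dataloader loop
    inputs = torch.randn(32, 1024, device=device)
    targets = torch.randn(32, 1024, device=device)
    optimizer.zero_grad(set_to_none=True)
    # Ops inside autocast run in reduced precision where that is safe;
    # precision-sensitive ops are kept in FP32 by the autocast policy.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = F.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()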


When is DDP not enough?
~~~~~~~~~~~~~~~~~~~~~~~~
A typical training run's memory footprint consists of model weights, activations, gradients, the input batch, and the optimizer state.
Since DDP replicates the model on each GPU, it only works when GPUs have sufficient capacity to accommodate the full footprint.
When models grow larger, more aggressive techniques might be useful:

-- `activation checkpointing <https://pytorch.org/docs/stable/checkpoint.html>`__: Instead of saving intermediate activations during the forward pass, the activations are recomputed during the backward pass. In this approach, we run more compute but save on memory footprint.
-- `Fully-Sharded Data Parallel <https://pytorch.org/docs/stable/fsdp.html>`__: Here the model is not replicated but "sharded" across all the GPUs, and computation is overlapped with communication in the forward and backward passes. Read our `blog <https://medium.com/pytorch/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ to learn how we trained a 1 Trillion parameter model with FSDP.
+- `Activation checkpointing <https://pytorch.org/docs/stable/checkpoint.html>`__: Instead of saving intermediate activations during the forward pass, the activations are recomputed during the backward pass. In this approach, we run more compute but save on memory footprint.
+- `Fully-Sharded Data Parallel <https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html>`__: Here the model is not replicated but "sharded" across all the GPUs, and computation is overlapped with communication in the forward and backward passes. Read our `blog <https://medium.com/pytorch/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ to learn how we trained a 1 Trillion parameter model with FSDP.
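
In practice, the activation checkpointing bullet above amounts to wrapping blocks of the forward pass in torch.utils.checkpoint. A minimal sketch, with placeholder blocks standing in for real model layers:

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative blocks only; the minGPT model itself is not part of this diff.
block1 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
block2 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())

x = torch.randn(8, 512, requires_grad=True)
# Activations inside each checkpointed block are not kept alive; they are
# recomputed during the backward pass, trading extra compute for lower memory.
h = checkpoint(block1, x, use_reentrant=False)
out = checkpoint(block2, h, use_reentrant=False)
out.sum().backward()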

Further Reading
---------------
- `Multi-Node training with DDP <ddp_series_multinode.html>`__ (previous tutorial in this series)
- `Mixed Precision training <https://pytorch.org/docs/stable/amp.html>`__
-- `Fully-Sharded Data Parallel <https://pytorch.org/docs/stable/fsdp.html>`__
+- `Fully-Sharded Data Parallel tutorial <https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__
- `Training a 1T parameter model with FSDP <https://medium.com/pytorch/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__
- `FSDP Video Tutorial Series <https://www.youtube.com/playlist?list=PL_lsbAsL_o2BT6aerEKgIoufVD_fodnuT>`__
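
The fully_shard page linked in the updated bullet above documents the FSDP2 API. A rough sketch of the usual pattern (shard the submodules first, then the root module), assuming PyTorch 2.6 or newer and a script launched with torchrun so that RANK, WORLD_SIZE, and LOCAL_RANK are set; the layer stack below is a placeholder, not minGPT:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point (import path assumes PyTorch 2.6+)

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Placeholder transformer blocks standing in for the model's layers.
model = torch.nn.Sequential(
    *[torch.nn.TransformerEncoderLayer(d_model=256, nhead=4) for _ in range(4)]
).cuda()

# Shard each block first, then the root module, so parameters are gathered
# block by block and communication can overlap with computation.
for block in model:
    fully_shard(block)
fully_shard(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(8, 32, 256, device="cuda")  # (sequence, batch, d_model)
model(x).sum().backward()
optimizer.step()

dist.destroy_process_group()

Run with something like torchrun --nproc_per_node=<GPUs per node> on each node; parameters, gradients, and optimizer state are then sharded across ranks instead of being replicated on every GPU as in DDP.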