Skip to content

Add torch.accelerator API to minGPT-ddp example #1370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions distributed/minGPT-ddp/mingpt/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import torch
from torch.utils.data import random_split
from torch.distributed import init_process_group, destroy_process_group
Expand All @@ -8,10 +9,18 @@
from omegaconf import DictConfig
import hydra

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
has_gpu = torch.accelerator.is_available()
gpu_count = torch.accelerator.device_count()
return has_gpu and gpu_count >= min_gpus

def ddp_setup():
init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
acc = torch.accelerator.current_accelerator()
rank = int(os.environ["LOCAL_RANK"])
device: torch.device = torch.device(f"{acc}:{rank}")
backend = torch.distributed.get_default_backend_for_device(device)
init_process_group(backend=backend)
torch.accelerator.set_device_index(rank)

def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
dataset = CharDataset(data_cfg)
Expand Down Expand Up @@ -40,6 +49,9 @@ def main(cfg: DictConfig):

destroy_process_group()


if __name__ == "__main__":
_min_gpu_count = 2
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
main()
10 changes: 7 additions & 3 deletions distributed/minGPT-ddp/mingpt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
# set torchrun variables
self.local_rank = int(os.environ["LOCAL_RANK"])
self.global_rank = int(os.environ["RANK"])
# set device
self.acc = torch.accelerator.current_accelerator()
self.device: torch.device = torch.device(f"{self.acc}:{self.local_rank}")
self.device_type = self.device.type
# data stuff
self.train_dataset = train_dataset
self.train_loader = self._prepare_dataloader(train_dataset)
Expand All @@ -58,7 +62,7 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
self.optimizer = optimizer
self.save_every = self.config.save_every
if self.config.use_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.scaler = torch.amp.GradScaler(self.device_type)
# load snapshot if available. only necessary on the first node.
if self.config.snapshot_path is None:
self.config.snapshot_path = "snapshot.pt"
Expand Down Expand Up @@ -93,7 +97,7 @@ def _load_snapshot(self):


def _run_batch(self, source, targets, train: bool = True) -> float:
with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(self.config.use_amp)):
with torch.set_grad_enabled(train), torch.amp.autocast(device_type=self.device_type, dtype=torch.float16, enabled=(self.config.use_amp)):
_, loss = self.model(source, targets)

if train:
Expand All @@ -119,7 +123,7 @@ def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
targets = targets.to(self.local_rank)
batch_loss = self._run_batch(source, targets, train)
if iter % 100 == 0:
print(f"[GPU{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
print(f"[RANK{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")

def _save_snapshot(self, epoch):
# capture snapshot
Expand Down
3 changes: 1 addition & 2 deletions distributed/minGPT-ddp/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
torch>=1.11.0
fsspec
torch>=2.7
boto3
hydra-core
requests
Expand Down
10 changes: 10 additions & 0 deletions distributed/minGPT-ddp/run_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
# Usage: bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to run. Default = 'main.py'
# num_gpus = num local gpus to use. Default = 16

# samples to run include:
# main.py

echo "Launching ${1:-main.py} with ${2:-16} gpus"
torchrun --standalone --nproc_per_node=${2:-16} ${1:-main.py}
5 changes: 5 additions & 0 deletions run_distributed_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ function distributed_ddp() {
uv run main.py || error "ddp example failed"
}

function distributed_minGPT-ddp() {
uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
}

function run_all() {
run distributed/tensor_parallelism
run distributed/ddp
run distributed/minGPT-ddp
}

# by default, run all examples
Expand Down