
Commit 1a26764

Add torch.accelerator API to mingGPT example (#34)
* Add torch.accelerator API to mingGPT example
1 parent d47f0f3 commit 1a26764

5 files changed: +38 −8 lines

distributed/minGPT-ddp/mingpt/main.py

Lines changed: 15 additions & 3 deletions
@@ -1,4 +1,5 @@
 import os
+import sys
 import torch
 from torch.utils.data import random_split
 from torch.distributed import init_process_group, destroy_process_group
@@ -8,10 +9,18 @@
 from omegaconf import DictConfig
 import hydra

+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus

 def ddp_setup():
-    init_process_group(backend="nccl")
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    acc = torch.accelerator.current_accelerator()
+    rank = int(os.environ["LOCAL_RANK"])
+    device: torch.device = torch.device(f"{acc}:{rank}")
+    backend = torch.distributed.get_default_backend_for_device(device)
+    init_process_group(backend=backend)
+    torch.accelerator.set_device_index(rank)

 def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
     dataset = CharDataset(data_cfg)
@@ -40,6 +49,9 @@ def main(cfg: DictConfig):

     destroy_process_group()

-
 if __name__ == "__main__":
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        sys.exit()
     main()
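
The diff above swaps the hard-coded NCCL/CUDA setup for the device-agnostic torch.accelerator API and guards the entry point with a minimum-device check. As a rough standalone sketch of those same calls (run outside torchrun; the variable names here are illustrative, not part of the commit):

import torch

# Mirrors the verify_min_gpu_count() guard added above: ask the generic
# accelerator API which backend is present and how many devices it exposes.
if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()   # e.g. device(type='cuda')
    n_devices = torch.accelerator.device_count()
    print(f"accelerator: {acc}, visible devices: {n_devices}")
else:
    print("no accelerator available; the DDP example would exit here")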

distributed/minGPT-ddp/mingpt/trainer.py

Lines changed: 7 additions & 3 deletions
@@ -48,6 +48,10 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         # set torchrun variables
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
+        # set device
+        self.acc = torch.accelerator.current_accelerator()
+        self.device: torch.device = torch.device(f"{self.acc}:{self.local_rank}")
+        self.device_type = self.device.type
         # data stuff
         self.train_dataset = train_dataset
         self.train_loader = self._prepare_dataloader(train_dataset)
@@ -58,7 +62,7 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         self.optimizer = optimizer
         self.save_every = self.config.save_every
         if self.config.use_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
+            self.scaler = torch.amp.GradScaler(self.device_type)
         # load snapshot if available. only necessary on the first node.
         if self.config.snapshot_path is None:
             self.config.snapshot_path = "snapshot.pt"
@@ -93,7 +97,7 @@ def _load_snapshot(self):


     def _run_batch(self, source, targets, train: bool = True) -> float:
-        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(self.config.use_amp)):
+        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=self.device_type, dtype=torch.float16, enabled=(self.config.use_amp)):
             _, loss = self.model(source, targets)

         if train:
@@ -119,7 +123,7 @@ def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
             targets = targets.to(self.local_rank)
             batch_loss = self._run_batch(source, targets, train)
             if iter % 100 == 0:
-                print(f"[GPU{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
+                print(f"[RANK{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")

     def _save_snapshot(self, epoch):
         # capture snapshot
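
The trainer changes above make AMP follow whatever device torch.accelerator reports instead of assuming CUDA: the scaler is built from the device type and autocast reuses it. A minimal single-process sketch of that GradScaler/autocast pattern (the linear model, optimizer, and random data below are placeholders, not part of the example):

import torch
import torch.nn.functional as F

# Resolve the device the same way the trainer does, falling back to CPU.
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
device_type = device.type
use_amp = device_type != "cpu"   # stands in for config.use_amp

model = torch.nn.Linear(8, 1).to(device)                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device_type, enabled=use_amp)   # device-agnostic scaler

source = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
    loss = F.mse_loss(model(source), targets)

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()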

distributed/minGPT-ddp/requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-torch>=1.11.0
-fsspec
+torch>=2.7
 boto3
 hydra-core
 requests
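
The floor moves from torch>=1.11.0 to torch>=2.7, which lines up with the torch.accelerator and torch.distributed.get_default_backend_for_device calls added above (they only exist in recent PyTorch releases); the explicit fsspec entry is dropped.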

distributed/minGPT-ddp/run_example.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Usage: bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'main.py'
+# num_gpus = num local gpus to use. Default = 16
+
+# samples to run include:
+# main.py
+
+echo "Launching ${1:-main.py} with ${2:-16} gpus"
+torchrun --standalone --nproc_per_node=${2:-16} ${1:-main.py}
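
The runner script below invokes this as "bash run_example.sh mingpt/main.py"; passing a second argument (for example, 2) overrides the default of 16 processes per node.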

run_distributed_examples.sh

Lines changed: 5 additions & 0 deletions
@@ -54,9 +54,14 @@ function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }

+function distributed_minGPT-ddp() {
+    uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
+}
+
 function run_all() {
     run distributed/tensor_parallelism
     run distributed/ddp
+    run distributed/minGPT-ddp
 }

 # by default, run all examples