diff --git a/distributed/minGPT-ddp/mingpt/main.py b/distributed/minGPT-ddp/mingpt/main.py
index ac6bac0807..8fdca56667 100644
--- a/distributed/minGPT-ddp/mingpt/main.py
+++ b/distributed/minGPT-ddp/mingpt/main.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import torch
 from torch.utils.data import random_split
 from torch.distributed import init_process_group, destroy_process_group
@@ -8,10 +9,18 @@
 from omegaconf import DictConfig
 import hydra
 
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
 
 def ddp_setup():
-    init_process_group(backend="nccl")
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    acc = torch.accelerator.current_accelerator()
+    rank = int(os.environ["LOCAL_RANK"])
+    device: torch.device = torch.device(f"{acc}:{rank}")
+    backend = torch.distributed.get_default_backend_for_device(device)
+    init_process_group(backend=backend)
+    torch.accelerator.set_device_index(rank)
 
 def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
     dataset = CharDataset(data_cfg)
@@ -40,6 +49,9 @@ def main(cfg: DictConfig):
 
     destroy_process_group()
 
-
 if __name__ == "__main__":
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        sys.exit()
     main()
diff --git a/distributed/minGPT-ddp/mingpt/trainer.py b/distributed/minGPT-ddp/mingpt/trainer.py
index 4d30695d41..1fbc457060 100644
--- a/distributed/minGPT-ddp/mingpt/trainer.py
+++ b/distributed/minGPT-ddp/mingpt/trainer.py
@@ -48,6 +48,10 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         # set torchrun variables
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
+        # set device
+        self.acc = torch.accelerator.current_accelerator()
+        self.device: torch.device = torch.device(f"{self.acc}:{self.local_rank}")
+        self.device_type = self.device.type
         # data stuff
         self.train_dataset = train_dataset
         self.train_loader = self._prepare_dataloader(train_dataset)
@@ -58,7 +62,7 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         self.optimizer = optimizer
         self.save_every = self.config.save_every
         if self.config.use_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
+            self.scaler = torch.amp.GradScaler(self.device_type)
         # load snapshot if available. only necessary on the first node.
         if self.config.snapshot_path is None:
             self.config.snapshot_path = "snapshot.pt"
@@ -93,7 +97,7 @@ def _load_snapshot(self):
 
     def _run_batch(self, source, targets, train: bool = True) -> float:
-        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(self.config.use_amp)):
+        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=self.device_type, dtype=torch.float16, enabled=(self.config.use_amp)):
             _, loss = self.model(source, targets)
 
         if train:
@@ -119,7 +123,7 @@ def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
             targets = targets.to(self.local_rank)
             batch_loss = self._run_batch(source, targets, train)
             if iter % 100 == 0:
-                print(f"[GPU{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
+                print(f"[RANK{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
 
     def _save_snapshot(self, epoch):
         # capture snapshot
diff --git a/distributed/minGPT-ddp/requirements.txt b/distributed/minGPT-ddp/requirements.txt
index 03872eca88..61162c6320 100644
--- a/distributed/minGPT-ddp/requirements.txt
+++ b/distributed/minGPT-ddp/requirements.txt
@@ -1,5 +1,4 @@
-torch>=1.11.0
-fsspec
+torch>=2.7
 boto3
 hydra-core
 requests
diff --git a/distributed/minGPT-ddp/run_example.sh b/distributed/minGPT-ddp/run_example.sh
new file mode 100755
index 0000000000..8e8a0acb57
--- /dev/null
+++ b/distributed/minGPT-ddp/run_example.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Usage: bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'main.py'
+# num_gpus = num local gpus to use. Default = 16
+
+# samples to run include:
+# main.py
+
+echo "Launching ${1:-main.py} with ${2:-16} gpus"
+torchrun --standalone --nproc_per_node=${2:-16} ${1:-main.py}
diff --git a/run_distributed_examples.sh b/run_distributed_examples.sh
index e1f579c072..39c9431500 100755
--- a/run_distributed_examples.sh
+++ b/run_distributed_examples.sh
@@ -54,9 +54,14 @@ function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }
 
+function distributed_minGPT-ddp() {
+    uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
+}
+
 function run_all() {
     run distributed/tensor_parallelism
     run distributed/ddp
+    run distributed/minGPT-ddp
 }
 
 # by default, run all examples