From d8e34297dd5d2ea6fe7cff89402bdc0e304fae01 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Mon, 16 Jun 2025 01:08:51 -0700
Subject: [PATCH] add tensorboard to training script

Summary:
- add tensorboard integration and separate the metrics by run id and
  replica id
- have an output folder per replica id

Test Plan:
image
---
 .gitignore      |  2 ++
 train_diloco.py | 29 +++++++++++++++++++++++------
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index e4feba58..65e56c89 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,5 @@ dist/
 
 # Torch
 cifar/
+
+output/
diff --git a/train_diloco.py b/train_diloco.py
index 58654f2b..c221558a 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -23,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
+from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
 
     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
@@ -171,12 +176,12 @@ def forward(self, x):
     num_params = sum(p.numel() for p in m.parameters())
     print(f"Total number of parameters: {num_params}")
 
-    sort_by_keyword = "self_" + device + "_time_total"
-
     def trace_handler(p):
-        p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
-        )
+        dir = f"{output_folder}/profiles"
+        if not os.path.exists(dir):
+            os.makedirs(dir, exist_ok=True)
+
+        p.export_chrome_trace(f"{dir}/step-{p.step_num}.json")
 
     # You can use an epoch based training but with faults it's easier to use step
     # based training.
@@ -188,6 +193,7 @@ def trace_handler(p):
     )
     prof.start()
 
+    tensorboard_key_prefix = f"Run:{RUN}"
     with DiLoCo(
         manager,
         module_partitions if USE_STREAMING else [m],
@@ -210,16 +216,27 @@
                 out = m(inputs)
                 loss = criterion(out, labels)
 
+                writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
                 loss.backward()
                 inner_optimizer.step()
 
+                writer.add_scalar(
+                    f"{tensorboard_key_prefix}/num_participants",
+                    manager.num_participants(),
+                    i,
+                )
+                writer.add_scalar(
+                    f"{tensorboard_key_prefix}/current_step", manager.current_step(), i
+                )
                 if manager.current_step() % 100 == 0:
                     print(f"[{manager.current_step()}] loss = {loss.item()}")
 
                 if manager.current_step() >= 15:
                     # complete training
                     prof.stop()
+                    writer.flush()
                     exit()
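
Note: to eyeball the logged metrics, point TensorBoard at the output folder
this patch creates, e.g. `tensorboard --logdir output/`; each replica shows up
under its own directory, and the scalars are namespaced by the `Run:<RUN>`
prefix. The snippet below is a minimal sketch for reading the event files back
programmatically; it assumes the `output/replica-0/tensorboard` layout from
this patch with RUN=0, and uses TensorBoard's EventAccumulator API. The paths
and tag names are illustrative, not part of the patch:

    # Sketch: read back the scalars written by the SummaryWriter in this patch.
    # Assumes REPLICA_GROUP_ID=0 and RUN=0, i.e. logs under output/replica-0/tensorboard.
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    acc = EventAccumulator("output/replica-0/tensorboard")
    acc.Reload()  # parse the event files on disk
    print(acc.Tags()["scalars"])  # expect tags like "Run:0/loss", "Run:0/num_participants"
    for event in acc.Scalars("Run:0/loss"):
        print(event.step, event.value)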