diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 9bc7b5f3..dde67b52 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -414,13 +414,26 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         rep0, rep1 = state_dicts

         for step in rep0.keys():
-            # Inner optimizer will be different, outer optimizer and model should be the same
+            # The inner optimizer and local model parameters will be different,
+            # e.g. with 2 replicas r1 and r2 syncing every 2 steps:
+            #
+            # - Manager Step 1
+            #   - Step 1: r1 and r2 step
+            #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
+            # - Manager Step 2
+            #   - Step 1: r1 steps but r2 fails
+            #   - Step 2:
+            #     - r1 steps, sync fails because r2 is down
+            #     - r1 recovers r2 from its model state at this step,
+            #       which is different from r1's model at the beginning
+            #       of Manager Step 2
+            #
+            # The outer optimizer and global model should be the same.
+
             torch.testing.assert_close(
-                rep1[step]["model"],
-                rep0[step]["model"],
+                rep1[step]["original_params"],
+                rep0[step]["original_params"],
                 check_device=False,
-                rtol=1e-4,
-                atol=1e-4,
             )
             torch.testing.assert_close(
                 rep1[step]["outer_optim"],
diff --git a/torchft/manager.py b/torchft/manager.py
index 27fdecd7..8a4590f3 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -219,6 +219,9 @@ def __init__(
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

+        # Used to synchronize the recovery operation
+        self._recovery_event: Optional[torch.cuda.Event] = None
+
         if self._group_rank == 0:
             if port is None:
                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
@@ -323,6 +326,7 @@ def allreduce(
             return fut

         self.wait_quorum()
+        num_participants: int = self.num_participants()

         if not self.is_participating():
             tensor.zero_()
@@ -337,6 +341,7 @@ def allreduce(
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
+                work.wait()
                 fut = work.get_future()

             stream: Optional[torch.cuda.Stream] = (
@@ -349,13 +354,13 @@ def allreduce(
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.Tensor:
-                nonlocal tensor, stream
+                nonlocal tensor, stream, num_participants

                 # change the stream to avoid making the callback stream
                 # dependent on process group stream running the allreduce
                 with torch.cuda.stream(stream) if stream is not None else nullcontext():
                     fut.value()
-                    tensor /= self.num_participants()
+                    tensor /= num_participants

                     return tensor

@@ -644,7 +649,12 @@ def _async_quorum(
             except Exception as e:
                 self._logger.exception(f"got exception in recovery: {e}")
                 self.report_error(e)
-                return
+
+            self._recovery_event = (
+                torch.cuda.current_stream().record_event()
+                if recovery_stream is not None
+                else None
+            )

     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
@@ -704,8 +714,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             with torch.profiler.record_function(
                 "torchft::manager::should_commmit::recovery_stream::synchronize"
             ):
-                if self._recovery_stream is not None:
-                    self._recovery_stream.synchronize()
+                if self._recovery_event is not None:
+                    self._recovery_event.synchronize()
+                    self._recovery_event = None

             with torch.profiler.record_function(
                 "torchft::manager::should_commit::current_stream::synchronize"
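
Note: the manager.py hunks above switch recovery synchronization from waiting on the whole recovery stream to waiting on a recorded CUDA event, and reset the event after the wait. Below is a minimal standalone sketch of that record_event/synchronize pattern; it is not torchft code, and the helper names recover and wait_for_recovery are hypothetical stand-ins for _async_quorum and should_commit.

from typing import Optional

import torch

# Side stream used for recovery copies; None when CUDA is unavailable.
recovery_stream: Optional[torch.cuda.Stream] = (
    torch.cuda.Stream() if torch.cuda.is_available() else None
)
# Event marking the end of the last recovery work enqueued on that stream.
recovery_event: Optional[torch.cuda.Event] = None


def recover(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Enqueue the recovery copy on the side stream and record an event
    # right after it, instead of synchronizing the whole stream later.
    global recovery_event
    if recovery_stream is not None:
        with torch.cuda.stream(recovery_stream):
            dst.copy_(src, non_blocking=True)
            # current_stream() is recovery_stream inside this context.
            recovery_event = torch.cuda.current_stream().record_event()
    else:
        dst.copy_(src)
        recovery_event = None


def wait_for_recovery() -> None:
    # Block the caller only until the recorded copy has finished, then
    # drop the event so later calls do not wait on stale work.
    global recovery_event
    if recovery_event is not None:
        recovery_event.synchronize()
        recovery_event = None

Compared to recovery_stream.synchronize(), waiting on the event only covers work recorded up to the record_event call, so work enqueued on the recovery stream afterwards does not extend the wait, and clearing the event prevents a later commit from blocking on a stale recovery.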