diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py index 9bc7b5f3..dde67b52 100644 --- a/torchft/local_sgd_integ_test.py +++ b/torchft/local_sgd_integ_test.py @@ -414,13 +414,26 @@ def test_diloco_recovery(self, use_cuda: bool) -> None: rep0, rep1 = state_dicts for step in rep0.keys(): - # Inner optimizer will be different, outer optimizer and model should be the same + # Inner optimizer and local model parameters will be different, e.g. + # with 2 replicas r1 and r2 that sync every 2 steps + # + # - Manager Step 1 + # - Step 1: r1 and r2 step + # - Step 2: r1 and r2 step, sync the model, quorum succeeds + # - Manager Step 2 + # - Step 1: r1 steps but r2 fails + # - Step 2: + # - r1 steps, sync fails because r2 is down + # - r1 recovers r2 from the model state at this step + # that is different from the model for r1 at the beginning + # of Manager Step 2 + # + # Outer optimizer and global model should be the same + torch.testing.assert_close( - rep1[step]["model"], - rep0[step]["model"], + rep1[step]["original_params"], + rep0[step]["original_params"], check_device=False, - rtol=1e-4, - atol=1e-4, ) torch.testing.assert_close( rep1[step]["outer_optim"],