
Commit 0bdb164

fix infinite recovery
Summary:
- we don't increase the `max_step` when a node is catching up because we don't call `should_commit`
- this can leave the node permanently behind and stuck in an infinite recovery loop (sketched below)
- so simply call `should_commit`
- note, this can result in the global parameters falling out of sync; the diff includes an RFC on how to fix that
- document another case where `should_commit` can return `True` even though it shouldn't, because the allreduce failed

Test Plan:
- tested on a cluster of 3 nodes by removing and then re-adding a node
- the `max_step` and `local_step` increase in the manager's state dict after both the failure and the recovery

metrics from the healthy node

<img width="1103" alt="Screenshot 2025-06-15 at 10 53 28 PM copy" src="https://github.com/user-attachments/assets/8640780c-fd20-4266-aa3c-3116776a9c68" />

metrics from the failed and recovered node

<img width="1101" alt="Screenshot 2025-06-15 at 10 56 49 PM copy" src="https://github.com/user-attachments/assets/cc2a1c57-715f-4e0a-8e00-7c62da525dc3" />
1 parent 63759df commit 0bdb164
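To make the failure mode in the summary concrete, here is a toy sketch of why skipping `should_commit` while catching up leads to an infinite recovery loop. Only the names `max_step`, `local_step`, and `should_commit` come from the commit; `ToyManager` and `catch_up` are hypothetical and not torchft's actual manager implementation.

```python
# Hypothetical sketch: if a catching-up node never calls should_commit, its
# own step never advances and it can never reach the group-wide max_step.


class ToyManager:
    def __init__(self) -> None:
        self.max_step = 0    # highest committed step seen across the group
        self.local_step = 0  # this node's own committed step

    def should_commit(self) -> bool:
        # The real manager also checks whether the allreduce succeeded;
        # here we always succeed to keep the sketch small.
        self.local_step += 1
        self.max_step = max(self.max_step, self.local_step)
        return True

    def needs_recovery(self) -> bool:
        # A node that is behind the group-wide max step must recover again.
        return self.local_step < self.max_step


def catch_up(manager: ToyManager, group_max_step: int, call_should_commit: bool) -> int:
    """Simulate a recovering node; return how many recovery rounds it took."""
    manager.max_step = group_max_step
    rounds = 0
    while manager.needs_recovery() and rounds < 10:
        rounds += 1
        if call_should_commit:
            # The fix: keep calling should_commit while catching up so the
            # node's step advances and the loop terminates.
            manager.should_commit()
        # Without the call, local_step never moves and recovery repeats
        # forever (capped at 10 rounds here so the sketch terminates).
    return rounds


print(catch_up(ToyManager(), group_max_step=3, call_should_commit=True))   # finishes in 3 rounds
print(catch_up(ToyManager(), group_max_step=3, call_should_commit=False))  # hits the 10-round cap
```

The only point of the sketch is that the recovering node's own step has to advance for the loop to terminate, which is what calling `should_commit` in the catching-up path achieves.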

File tree

1 file changed: +37 -1 lines changed


torchft/local_sgd.py

Lines changed: 37 additions & 1 deletion
@@ -357,13 +357,39 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         if len(self._allreduce_futures) == 0:
-            return True
+            assert self._fragment_sync_delay > 0
+            # This can happen when using `fragment_sync_delay`. The node
+            # might not have participated in syncing of this fragment.
+            #
+            # The allreduce for the other nodes that did participate might
+            # actually succeed, and in that case we shouldn't allow recovery
+            # from this node.
+            #
+            # We do need to increase the `max_step` here so we
+            # don't end up in an infinite loop of needing to recover.
+            #
+            # TODO: We can add an `is_catching_up` flag to the state_dict
+            # to disallow recoveries from this node. Such nodes can
+            # be excluded from the `max_step` calculation unless all
+            # nodes are catching up.
+            return self._manager.should_commit()
 
         self.wait()
 
         # Restore the parameters back to the previous state
         self.restore_parameters()
 
+        # This can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was in flight. The in-flight allreduce may or may
+        # not have been aborted.
+        #
+        # We consider it successful anyway.
+        #
+        # TODO: We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()
 
         if should_commit:
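The `is_catching_up` TODO in the hunk above could look roughly like the following. This is a hypothetical sketch of the exclusion rule only; the `ReplicaState` layout and the `group_max_step` helper are assumptions, not torchft's actual state_dict format or API.

```python
# Hypothetical sketch of the `is_catching_up` TODO above: exclude catching-up
# replicas from the max_step calculation unless every replica is catching up.

from typing import Dict, TypedDict


class ReplicaState(TypedDict):
    max_step: int
    is_catching_up: bool


def group_max_step(states: Dict[str, ReplicaState]) -> int:
    """Pick the step that lagging replicas must catch up to."""
    healthy = [s["max_step"] for s in states.values() if not s["is_catching_up"]]
    if healthy:
        # Normal case: only replicas that are not catching up define the
        # target, so a lagging node is never used as a recovery source.
        return max(healthy)
    # Degenerate case: everyone is catching up, so fall back to all replicas.
    return max(s["max_step"] for s in states.values())


states: Dict[str, ReplicaState] = {
    "replica_0": {"max_step": 10, "is_catching_up": False},
    "replica_1": {"max_step": 10, "is_catching_up": False},
    "replica_2": {"max_step": 7, "is_catching_up": True},  # recovering node
}
assert group_max_step(states) == 10
```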
@@ -708,6 +734,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()
 
+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing
+        # and thus of the allreduce failing.
+        # - We could maybe do a quorum again right before preparing for a fragment
+        #   using `shrink_only`. This might make it tricky for new nodes to join
+        #   though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        #   quorum call. Then we can do a quorum right before allreduce and avoid
+        #   doing quorums after commit.
+
         # We need to set make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
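The sequence-number alternative in the TODO above might amount to bookkeeping along these lines. All names here (`QuorumTracker`, `quorum_seq`, `sync_fragment`) are hypothetical; the sketch only illustrates bumping a counter on every quorum and persisting it through the state dict so the quorum can be run right before the allreduce instead of after commit.

```python
# Rough sketch of the sequence-number idea from the TODO above; not torchft's
# actual manager API.


class QuorumTracker:
    def __init__(self) -> None:
        self.quorum_seq = 0  # bumped on every quorum call

    def quorum(self) -> None:
        # The real system would run the manager's quorum here; the sketch
        # only records that it happened.
        self.quorum_seq += 1

    def state_dict(self) -> dict:
        # Recovering nodes restore this, so they know which quorum epoch
        # the parameters they copied belong to.
        return {"quorum_seq": self.quorum_seq}

    def load_state_dict(self, state: dict) -> None:
        self.quorum_seq = state["quorum_seq"]


def sync_fragment(tracker: QuorumTracker, last_allreduce_seq: int) -> int:
    """Run the quorum lazily, right before the allreduce, rather than after
    every commit. Returns the sequence number the allreduce ran under."""
    if tracker.quorum_seq == last_allreduce_seq:
        # No quorum has happened since the previous allreduce, so run one
        # now to refresh the group just before communicating.
        tracker.quorum()
    # ... issue the allreduce here under tracker.quorum_seq ...
    return tracker.quorum_seq
```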
