Skip to content

Commit 7a04abc

Browse files
committed
support multiple outer optims for diloco
Summary: - Support passing in a different outer optimizer for each fragment. - Currently accepts both a list of optimizers and a single optimizer, for backward compatibility.
1 parent 6524d16 commit 7a04abc

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

torchft/local_sgd.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ def __init__(
550550
manager: Manager,
551551
model_fragments: List[nn.Module],
552552
inner_optimizer: optim.Optimizer,
553-
outer_optimizer: optim.Optimizer,
553+
# TODO: a single optimizer (instead of a list) is still accepted for backward compatibility
554+
outer_optimizer: optim.Optimizer | list[optim.Optimizer],
554555
sync_every: int,
555556
backup_device: Optional[torch.device] = None,
556557
pin_memory: bool = True,
@@ -575,6 +576,11 @@ def __init__(
575576
fragment_update_alpha: Determines how to mix the local and global optimized parameters
576577
"""
577578

579+
if isinstance(outer_optimizer, list):
580+
assert len(outer_optimizer) == len(
581+
model_fragments
582+
), "The number of outer optimizers must match the number of model fragments"
583+
578584
if manager._use_async_quorum:
579585
raise ValueError(
580586
"Using DiLoCo require synchronous quorum to be enabled. "
@@ -623,8 +629,11 @@ def __init__(
623629
model_fragment,
624630
math.floor((sync_every / len(model_fragments)) * (i + 1)),
625631
inner_optimizer,
626-
# TODO: Support different outer optimizers for each fragment
627-
outer_optimizer,
632+
(
633+
outer_optimizer[i]
634+
if isinstance(outer_optimizer, list)
635+
else outer_optimizer
636+
),
628637
sync_every,
629638
backup_device,
630639
pin_memory,

0 commit comments

Comments
 (0)