Skip to content

Commit 323fb47

Browse files
committed
enable merging parameters for diloco
1 parent f0fa70b commit 323fb47

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

torchft/local_sgd.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,14 @@ def __init__(
213213
self.should_quantize = should_quantize
214214

215215
self._grads: Dict[str, torch.Tensor] = {}
216+
217+
# Used to save global parameters so that they can be restored in case
218+
# commit fails
216219
self.original_parameters: Dict[str, torch.Tensor] = {}
217220

221+
# Used to mix the local and global parameters
222+
self._local_parameters: Dict[str, torch.Tensor] = {}
223+
218224
for name, p in self._model_fragment.named_parameters():
219225
if isinstance(p, DTensor):
220226
p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
237243
param_to_local = extract_local_tensor(p.data)
238244
self.original_parameters[name].copy_(param_to_local, non_blocking=True)
239245

246+
def _save_local_parameters(self) -> None:
247+
"""
248+
Saves a copy of the model's parameters.
249+
"""
250+
with torch.no_grad():
251+
for name, p in self._model_fragment.named_parameters():
252+
self._local_parameters[name] = extract_local_tensor(p.data)
253+
240254
@torch.profiler.record_function("torchft::local_sgd::restore_parameters")
241255
def restore_parameters(self) -> None:
242256
with torch.no_grad():
@@ -282,6 +296,21 @@ def _set_grads(self) -> None:
282296
else:
283297
p.grad = self._grads[name]
284298

299+
def _clear_local_parameters(self) -> None:
300+
"""
301+
Clears the saved copy of the model's parameters
302+
"""
303+
self._local_parameters = {}
304+
305+
def _merge_parameters(self) -> None:
306+
"""
307+
Merges the local and global parameters.
308+
"""
309+
for name, p in self._model_fragment.named_parameters():
310+
torch.lerp(
311+
p.data, self._local_parameters[name], 1 - self._fragment_update_alpha
312+
)
313+
285314
@torch.profiler.record_function("torchft::local_sgd::wait")
286315
def wait(self) -> None:
287316
"""
@@ -350,6 +379,8 @@ def perform_sync(self) -> bool:
350379

351380
self.wait()
352381

382+
# save the parameters so they can be used for merging
383+
self._save_local_parameters()
353384
# Restore the parameters back to the previous state
354385
self.restore_parameters()
355386

@@ -360,8 +391,12 @@ def perform_sync(self) -> bool:
360391
self._set_grads()
361392
self._outer_optimizer.step()
362393
self.save_parameters()
394+
self._merge_parameters()
363395
self._outer_optimizer.zero_grad()
364396

397+
# free up memory
398+
self._clear_local_parameters()
399+
365400
return should_commit
366401

367402
def _average_grads(self) -> None:
@@ -513,12 +548,6 @@ def __init__(
513548
if fragment_update_alpha < 0 or fragment_update_alpha > 1:
514549
raise ValueError("fragment_update_alpha must be between 0 and 1")
515550

516-
# TODO: Support `fragment_update_alpha`
517-
if fragment_update_alpha != 0.0:
518-
raise ValueError(
519-
"Merging local parameters with global parameters is not supported yet"
520-
)
521-
522551
super().__init__()
523552
self._manager = manager
524553

0 commit comments

Comments
 (0)