@@ -257,17 +257,32 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    def _save_grads(self) -> None:
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
-
-            del self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]

     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
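
Note: the pseudogradient round-trip that _save_grads and _set_grads implement can be sketched on a single plain (non-DTensor) parameter. The following is a minimal standalone illustration of the technique, not torchft code; snapshot and outer_opt are made-up names, and the DTensor branch (which rewraps the local shard with DTensor.from_local instead of assigning it directly) is only mentioned in a comment.

import torch

# One toy parameter plus a snapshot taken before the inner steps,
# playing the role of self.original_parameters[name].
param = torch.nn.Parameter(torch.randn(4))
snapshot = param.detach().clone()

# Stand-in for several local optimizer steps mutating the parameter in place.
with torch.no_grad():
    param.add_(0.1 * torch.randn(4))

with torch.no_grad():
    # _save_grads: the drift since the snapshot is the pseudogradient
    # (in torchft this tensor is what gets allreduced across replicas).
    pseudogradient = param.detach() - snapshot

    # restore_parameters: rewind the parameter to the snapshot.
    param.copy_(snapshot)

    # _set_grads: install the (averaged) pseudogradient as .grad so an outer
    # optimizer can consume it. For a DTensor parameter the local shard would
    # be rewrapped with DTensor.from_local(...) instead of assigned directly.
    param.grad = pseudogradient

# The outer optimizer then steps on .grad as usual.
outer_opt = torch.optim.SGD([param], lr=0.7)
outer_opt.step()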
@@ -304,14 +319,9 @@ def prepare_sync(self) -> None:
         Calculates the pseudogradients, averages them across the manager group, and
         starts allreduce on the pseudogradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
+        assert len(self._allreduce_futures) == 0

         # Make sure tensors are available to `_stream`
         if self._stream is not None:
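
Note: the prepare_sync / wait split exists so that communication of the pseudogradients can overlap with other work. The same start-now/wait-later pattern, expressed with plain torch.distributed async collectives instead of torchft's manager allreduce (so quantization and fault tolerance are ignored, and the process group is assumed to be initialized elsewhere), looks roughly like this:

import torch
import torch.distributed as dist

def start_allreduces(grads: dict[str, torch.Tensor]) -> list:
    # Kick off one asynchronous allreduce per pseudogradient and return the
    # work handles without blocking (mirrors prepare_sync).
    return [
        dist.all_reduce(g, op=dist.ReduceOp.SUM, async_op=True)
        for g in grads.values()
    ]

def finish_allreduces(grads: dict[str, torch.Tensor], works: list) -> None:
    # Block until the collectives complete, then average (mirrors wait).
    for w in works:
        w.wait()
    world_size = dist.get_world_size()
    for g in grads.values():
        g.div_(world_size)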
@@ -371,18 +381,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)

-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +443,9 @@ def _allreduce_bucketized(self) -> None:
         """
         Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
             grads,
             bucket_size_bytes=self.bucket_cap_mb,
         )
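
Note: _allreduce_bucketized trades many small collectives for a few large ones: pseudogradients are packed into a flat buffer whose size is capped by bucket_cap_mb, the buffer is allreduced once, and the averaged slices are copied back into the original tensors. A simplified single-bucket version of that packing idea, using a plain torch.distributed allreduce instead of the manager and ignoring dtype grouping and the size cap, might look like:

import torch
import torch.distributed as dist

def allreduce_flattened(tensors: list[torch.Tensor]) -> None:
    # Pack all tensors into one contiguous buffer so a single collective
    # replaces one allreduce per tensor.
    flat = torch.cat([t.reshape(-1) for t in tensors])

    # Sum across the process group, then divide to get the average.
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat.div_(dist.get_world_size())

    # Copy the averaged slices back into the original tensors in place.
    offset = 0
    for t in tensors:
        n = t.numel()
        t.copy_(flat[offset : offset + n].view_as(t))
        offset += n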