
Implement collective gather op #9435


Open
wants to merge 10 commits into master
52 changes: 52 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -360,6 +360,58 @@ def test_scatter(self):
np.testing.assert_array_equal(value, [ordinal])

@staticmethod
def _gather(scalar: bool = False):
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()

# If scalar, tensors are tensor(i). Otherwise they are tensor([i]).
# The two cases follow different code paths and should be tested separately.
if scalar:
item = xr.global_ordinal()
dummy = -1.0
else:
item = [xr.global_ordinal()]
dummy = [-1.0]

tensor = torch.tensor(item, device=device, dtype=torch.float)

# Instantiate tensors on device 0 to receive the results
output_tensors = None
if xr.global_ordinal() == 0:
output_tensors = [
torch.tensor(dummy, device=device, dtype=torch.float)
for _ in range(world_size)
]

dist.gather(tensor, output_tensors, dst=0)
if not output_tensors:
return None
else:
return [t.cpu() for t in output_tensors]

@parameterized.named_parameters(('scalar', True), ('tensor', False))
def test_gather(self, scalar):
# self._gather creates tensor(i) (if `scalar`) or tensor([i]) on device i.
# The results are gathered on device 0. All other devices get None.
results = pjrt.run_multiprocess(self._gather, scalar)
if scalar:
expected = [
torch.tensor(i, dtype=torch.float)
for i in range(tpu.num_expected_global_devices())
]
else:
expected = [
torch.tensor([i], dtype=torch.float)
for i in range(tpu.num_expected_global_devices())
]
for ordinal, value in results.items():
if ordinal == 0:
torch.testing.assert_close(value, expected)
else:
assert value is None

def _reduce():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
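
For context, a minimal standalone driver for the gather path exercised by the test above could look like the sketch below. Only the dist.gather semantics (the destination rank supplies the output list, every other rank passes None) come from this change; the torch_xla.launch entry point and the printed values are illustrative assumptions.

# Hypothetical standalone sketch, not part of this diff: rank 0 collects one
# tensor per rank via the new XLA gather; every other rank passes None.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(index):
  dist.init_process_group("xla", init_method="xla://")
  device = torch_xla.device()
  world_size = xr.world_size()

  tensor = torch.tensor([xr.global_ordinal()],
                        device=device,
                        dtype=torch.float)

  # Only the destination rank allocates the output list.
  output = None
  if xr.global_ordinal() == 0:
    output = [torch.empty_like(tensor) for _ in range(world_size)]

  dist.gather(tensor, output, dst=0)
  if output is not None:
    print([t.cpu() for t in output])  # [tensor([0.]), tensor([1.]), ...]


if __name__ == "__main__":
  torch_xla.launch(_mp_fn)

Note that every rank must call dist.gather even though only rank 0 receives data: the backend performs the same redundant all_gather on non-destination ranks to avoid XLA deadlocks.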
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
@@ -358,7 +358,6 @@ def test_barrier(self):
@parameterized.parameters(
'allreduce_coalesced',
'alltoall',
'gather',
'recv_anysource',
'monitored_barrier',
)
44 changes: 40 additions & 4 deletions torch_xla/distributed/xla_backend.py
@@ -5,8 +5,7 @@
import torch_xla.runtime as xr
from torch_xla._internal import rendezvous
import logging
import os
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, GatherOptions, ReduceOptions


def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -264,8 +263,45 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
output.copy_(result)
return _ret_work(output)

def gather(self, *args):
raise NotImplementedError
# Called by torch.distributed.gather. Call site example:
# https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043
# Input tensors are gathered into the list of output tensors on the dst device.
# The output tensors list is None for all non-dst devices.
# This is an inefficient operation: to avoid XLA deadlocks, it performs
# redundant gathers on non-dst devices and materializes the result.
def gather(self, output_tensors_list: list[list[torch.Tensor]],
input_tensor_list: list[torch.Tensor], opts: GatherOptions):
rank = xr.global_ordinal()

for i, input_tensor in enumerate(input_tensor_list):
is_scalar = input_tensor.dim() == 0
input_for_all_gather = (
input_tensor.clone().reshape(1) if is_scalar else input_tensor)

gathered_tensor = xm.all_gather(
input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)

# Syncing is required to keep the heterogeneous copying below at the
# Python layer, avoiding deadlocks due to mismatched HLO.
torch_xla.sync()

if rank == opts.rootRank:
output_tensors = output_tensors_list[i]
if is_scalar:
for j in range(xr.world_size()):
output_tensors[j].copy_(gathered_tensor[j])
else:
chunk_size = input_tensor.shape[0]
gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0)
for j, chunk in enumerate(gathered_chunks):
if chunk.shape != output_tensors[j].shape:
chunk = chunk.reshape(output_tensors[j].shape)
output_tensors[j].copy_(chunk)

if rank == opts.rootRank:
return _ret_work(output_tensors_list)
else:
return _ret_work([[]])
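
To illustrate the split/copy logic in the non-scalar branch above, here is a small CPU-only toy example (not part of this PR; world_size and chunk_size are made-up values) showing how the concatenated all-gather result is split back into per-rank chunks and copied into the output tensors.

# Toy illustration (CPU only) of the split/copy step in the non-scalar branch.
import torch

world_size = 4
chunk_size = 2  # leading dimension contributed by each rank

# Simulate what xm.all_gather(..., dim=0) produces: every rank's input
# concatenated along dim 0, i.e. shape [world_size * chunk_size].
gathered = torch.cat(
    [torch.full((chunk_size,), float(r)) for r in range(world_size)], dim=0)

output_tensors = [torch.empty(chunk_size) for _ in range(world_size)]
for j, chunk in enumerate(torch.split(gathered, chunk_size, dim=0)):
  if chunk.shape != output_tensors[j].shape:
    chunk = chunk.reshape(output_tensors[j].shape)
  output_tensors[j].copy_(chunk)

print(output_tensors)  # [tensor([0., 0.]), tensor([1., 1.]), tensor([2., 2.]), tensor([3., 3.])]

The shape check here mirrors the defensive reshape in the implementation above.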

# Called by torch.distributed.scatter. Call site example:
# https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146