diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index a9681fe5e06..1afe859a204 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -389,6 +389,59 @@ def test_scatter(self):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [ordinal])
 
+  @staticmethod
+  def _gather(scalar: bool = False):
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+
+    # If scalar, tensors are tensor(i). Otherwise they are tensor([i]).
+    # The two cases follow different code paths and should be tested separately.
+    if scalar:
+      item = xr.global_ordinal()
+      dummy = -1.0
+    else:
+      item = [xr.global_ordinal()]
+      dummy = [-1.0]
+
+    tensor = torch.tensor(item, device=device, dtype=torch.float)
+
+    # Instantiate tensors on device 0 to receive the results.
+    output_tensors = None
+    if xr.global_ordinal() == 0:
+      output_tensors = [
+          torch.tensor(dummy, device=device, dtype=torch.float)
+          for _ in range(world_size)
+      ]
+
+    dist.gather(tensor, output_tensors, dst=0)
+    if not output_tensors:
+      return None
+    else:
+      return [t.cpu() for t in output_tensors]
+
+  @parameterized.named_parameters(('scalar', True), ('tensor', False))
+  def test_gather(self, scalar):
+    # self._gather instantiates tensor i or [i], depending on the value of
+    # `scalar`, on device i. The results are gathered on device 0.
+    # All other devices get None.
+    results = pjrt.run_multiprocess(self._gather, scalar)
+    if scalar:
+      expected = [
+          torch.tensor(i, dtype=torch.float)
+          for i in range(tpu.num_expected_global_devices())
+      ]
+    else:
+      expected = [
+          torch.tensor([i], dtype=torch.float)
+          for i in range(tpu.num_expected_global_devices())
+      ]
+    for ordinal, value in results.items():
+      if ordinal == 0:
+        torch.testing.assert_close(value, expected)
+      else:
+        assert value is None
+
   @staticmethod
   def _reduce():
     dist.init_process_group("xla", init_method='xla://')
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 867226b5451..3af6aaa8a08 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -357,7 +357,6 @@ def test_barrier(self):
 
   @parameterized.parameters(
       'allreduce_coalesced',
-      'gather',
       'recv_anysource',
       'monitored_barrier',
   )
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index c0dd3104cf0..87ed4bbd7a5 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -6,7 +6,7 @@ from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions, GatherOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -280,8 +280,45 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
     output.copy_(result)
     return _ret_work(output)
 
-  def gather(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.gather. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043
+  # Input tensors are gathered into a list of output tensors on the dst device.
+  # The output tensor list is None on all non-dst devices.
+  # This is an inefficient operation. In order to avoid XLA deadlocks, it
+  # performs redundant gathers on non-dst devices and materializes the result.
+  def gather(self, output_tensors_list: list[list[torch.Tensor]],
+             input_tensor_list: list[torch.Tensor], opts: GatherOptions):
+    rank = xr.global_ordinal()
+
+    for i, input_tensor in enumerate(input_tensor_list):
+      is_scalar = input_tensor.dim() == 0
+      input_for_all_gather = (
+          input_tensor.clone().reshape(1) if is_scalar else input_tensor)
+
+      gathered_tensor = xm.all_gather(
+          input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)
+
+      # Syncing is required to keep the rank-dependent copying below at the
+      # Python layer, avoiding deadlocks due to mismatched HLO across ranks.
+      torch_xla.sync()
+
+      if rank == opts.rootRank:
+        output_tensors = output_tensors_list[i]
+        if is_scalar:
+          for j in range(xr.world_size()):
+            output_tensors[j].copy_(gathered_tensor[j])
+        else:
+          chunk_size = input_tensor.shape[0]
+          gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0)
+          for j, chunk in enumerate(gathered_chunks):
+            if chunk.shape != output_tensors[j].shape:
+              chunk = chunk.reshape(output_tensors[j].shape)
+            output_tensors[j].copy_(chunk)
+
+    if rank == opts.rootRank:
+      return _ret_work(output_tensors_list)
+    else:
+      return _ret_work([[]])
 
   # Called by torch.distributed.scatter. Call site example:
   # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
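
For reference, a minimal usage sketch (not part of the patch) showing how the new gather path is exercised from user code. It mirrors the added _gather test above and assumes a multi-process XLA/TPU setup where the "xla" backend and the xla:// init method are available; the function name and tensor shapes here are illustrative only.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def gather_example():
  # Runs on every process; only rank 0 receives the gathered tensors.
  dist.init_process_group("xla", init_method="xla://")
  device = torch_xla.device()

  # Each rank contributes a one-element tensor holding its ordinal.
  tensor = torch.tensor([xr.global_ordinal()], device=device, dtype=torch.float)

  # Only the destination rank allocates receive buffers; other ranks pass None.
  output_tensors = None
  if xr.global_ordinal() == 0:
    output_tensors = [
        torch.zeros(1, device=device, dtype=torch.float)
        for _ in range(xr.world_size())
    ]

  dist.gather(tensor, output_tensors, dst=0)
  # On rank 0, output_tensors now holds [tensor([0.]), tensor([1.]), ...];
  # on every other rank it is still None.
  return output_tensors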