From 5203f38916348d1d83d8361cb8c72c7b4e88a9f2 Mon Sep 17 00:00:00 2001 From: Brendan Folie Date: Mon, 30 Jun 2025 16:54:23 +0000 Subject: [PATCH 1/8] first attempt at gather, hangs --- test/pjrt/test_collective_ops_tpu.py | 30 ++++++++++++++++++++++++++++ torch_xla/distributed/xla_backend.py | 11 ++++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 7ee9e7d8a66f..49de96a3db0d 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -126,6 +126,36 @@ def test_scatter(self): for ordinal, value in results.items(): np.testing.assert_array_equal(value, [ordinal]) + @staticmethod + def _gather(): + dist.init_process_group("xla", init_method='xla://') + device = torch_xla.device() + world_size = xr.world_size() + tensor = torch.tensor([xr.global_ordinal()], + device=device, + dtype=torch.float) + output_tensors = None + if xr.global_ordinal() == 0: + output_tensors = [ + torch.tensor([-1.0], device=device, dtype=torch.float) + for _ in range(world_size) + ] + + dist.gather(tensor, output_tensors, dst=0) + if output_tensors is None: + return None + else: + return [t.cpu() for t in output_tensors] + + def test_gather(self): + results = pjrt.run_multiprocess(self._gather) + expected = list(range(tpu.num_expected_global_devices())) + for ordinal, value in results.items(): + if ordinal == 0: + np.testing.assert_array_equal(value, expected) + else: + assert value is None + @staticmethod def _all_to_all(pin_layout): device = torch_xla.device() diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index daef50c243dc..4ead7ebbb32e 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -5,7 +5,7 @@ from torch_xla._internal import rendezvous import logging import os -from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions +from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, GatherOptions def _create_xla_process_group(prefix_store, rank, size, timeout): @@ -250,7 +250,14 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes, output.copy_(result) return _ret_work(output) - def gather(self, *args): + def gather(self, output_tensors_list: list[list[torch.Tensor]], input_tensor_list: list[torch.Tensor], opts: GatherOptions): + if xr.global_ordinal() == opts.rootRank: + outputs = output_tensors_list + else: + outputs = [[torch.zeros_like(input_tensor)] * xr.world_size() for input_tensor in input_tensor_list] + return self.allgather(outputs, input_tensor_list) + + raise NotImplementedError # Called by torch.distributed.scatter. 
Call site example: From e9fe828d950e050fdd232f62f7f6f43fc5e8887e Mon Sep 17 00:00:00 2001 From: bfolie Date: Tue, 1 Jul 2025 21:31:03 +0000 Subject: [PATCH 2/8] get gather working --- test/pjrt/test_collective_ops_tpu.py | 34 ++++++++++++++------ test/test_torch_distributed_xla_backend.py | 1 - torch_xla/distributed/xla_backend.py | 36 +++++++++++++++++----- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 49de96a3db0d..3fcf3652b761 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -127,32 +127,46 @@ def test_scatter(self): np.testing.assert_array_equal(value, [ordinal]) @staticmethod - def _gather(): + def _gather(scalar: bool = False): dist.init_process_group("xla", init_method='xla://') device = torch_xla.device() world_size = xr.world_size() - tensor = torch.tensor([xr.global_ordinal()], - device=device, - dtype=torch.float) + if scalar: + item = xr.global_ordinal() + dummy = -1.0 + else: + item = [xr.global_ordinal()] + dummy = [-1.0] + tensor = torch.tensor(item, device=device, dtype=torch.float) output_tensors = None if xr.global_ordinal() == 0: output_tensors = [ - torch.tensor([-1.0], device=device, dtype=torch.float) + torch.tensor(dummy, device=device, dtype=torch.float) for _ in range(world_size) ] dist.gather(tensor, output_tensors, dst=0) - if output_tensors is None: + if not output_tensors: return None else: return [t.cpu() for t in output_tensors] - def test_gather(self): - results = pjrt.run_multiprocess(self._gather) - expected = list(range(tpu.num_expected_global_devices())) + @parameterized.named_parameters(('scalar', True), ('tensor', False)) + def test_gather(self, scalar): + results = pjrt.run_multiprocess(self._gather, scalar) + if scalar: + expected = [ + torch.tensor(i, dtype=torch.float) + for i in range(tpu.num_expected_global_devices()) + ] + else: + expected = [ + torch.tensor([i], dtype=torch.float) + for i in range(tpu.num_expected_global_devices()) + ] for ordinal, value in results.items(): if ordinal == 0: - np.testing.assert_array_equal(value, expected) + torch.testing.assert_close(value, expected) else: assert value is None diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index 99b721a4fa16..098aea5cdf31 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -359,7 +359,6 @@ def test_barrier(self): 'reduce', 'allreduce_coalesced', 'alltoall', - 'gather', 'recv_anysource', 'monitored_barrier', ) diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 4ead7ebbb32e..095d784f4d1a 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -1,10 +1,10 @@ import torch import torch.distributed as dist +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import rendezvous import logging -import os from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, GatherOptions @@ -250,15 +250,35 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes, output.copy_(result) return _ret_work(output) - def gather(self, output_tensors_list: list[list[torch.Tensor]], input_tensor_list: list[torch.Tensor], opts: GatherOptions): - if xr.global_ordinal() == opts.rootRank: - outputs = output_tensors_list - else: - outputs = 
[[torch.zeros_like(input_tensor)] * xr.world_size() for input_tensor in input_tensor_list] - return self.allgather(outputs, input_tensor_list) + def gather(self, output_tensors_list: list[list[torch.Tensor]], + input_tensor_list: list[torch.Tensor], opts: GatherOptions): + rank = xr.global_ordinal() + input_tensor = input_tensor_list[0] + is_scalar = input_tensor.dim() == 0 + input_for_all_gather = ( + input_tensor.clone().reshape(1) if is_scalar else input_tensor) + gathered_tensor = xm.all_gather( + input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False) + torch_xla.sync() - raise NotImplementedError + if rank == opts.rootRank: + output_tensors = output_tensors_list[0] + if is_scalar: + for i in range(xr.world_size()): + output_tensors[i].copy_(gathered_tensor[i]) + else: + chunk_size = input_tensor.shape[0] + gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0) + for i, chunk in enumerate(gathered_chunks): + if chunk.shape != output_tensors[i].shape: + chunk = chunk.reshape(output_tensors[i].shape) + output_tensors[i].copy_(chunk) + + if rank == opts.rootRank: + return _ret_work(output_tensors_list[0]) + else: + return _ret_work([]) # Called by torch.distributed.scatter. Call site example: # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146 From 45626874ceb895a4294b93d5ca02d017e14a1dee Mon Sep 17 00:00:00 2001 From: bfolie Date: Tue, 1 Jul 2025 21:45:25 +0000 Subject: [PATCH 3/8] make gather work for coalesced inputs --- test/pjrt/test_collective_ops_tpu.py | 3 ++ torch_xla/distributed/xla_backend.py | 53 ++++++++++++++++------------ 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 3fcf3652b761..2ae994b79040 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -153,6 +153,9 @@ def _gather(scalar: bool = False): @parameterized.named_parameters(('scalar', True), ('tensor', False)) def test_gather(self, scalar): + # self._gather instantiates tensor i or [i], depending on the value of + # `scalar`, on device i. The results are gathered on device 0. + # All other devices get None. results = pjrt.run_multiprocess(self._gather, scalar) if scalar: expected = [ diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 095d784f4d1a..01d6117034b4 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -250,35 +250,44 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes, output.copy_(result) return _ret_work(output) + # Called by torch.distributed.gather. Call site example: + # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043 + # Input tensors are gathered into list of output tensors on the dst device. + # Output tensors list is None for all non-dst devices. + # This is an inefficient operation. In order to avoid XLA deadlocks it + # performs redundant gathers on non-dst devices and materializes the result. 
def gather(self, output_tensors_list: list[list[torch.Tensor]], input_tensor_list: list[torch.Tensor], opts: GatherOptions): rank = xr.global_ordinal() - input_tensor = input_tensor_list[0] - is_scalar = input_tensor.dim() == 0 - input_for_all_gather = ( - input_tensor.clone().reshape(1) if is_scalar else input_tensor) - gathered_tensor = xm.all_gather( - input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False) - torch_xla.sync() + for i, input_tensor in enumerate(input_tensor_list): + is_scalar = input_tensor.dim() == 0 + input_for_all_gather = ( + input_tensor.clone().reshape(1) if is_scalar else input_tensor) + + gathered_tensor = xm.all_gather( + input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False) + # Syncing is required to keep the heterogeneous copying below at the + # Python layer, avoiding deadlocks due to mismatched HLO. + torch_xla.sync() + + if rank == opts.rootRank: + output_tensors = output_tensors_list[i] + if is_scalar: + for j in range(xr.world_size()): + output_tensors[j].copy_(gathered_tensor[j]) + else: + chunk_size = input_tensor.shape[0] + gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0) + for j, chunk in enumerate(gathered_chunks): + if chunk.shape != output_tensors[j].shape: + chunk = chunk.reshape(output_tensors[j].shape) + output_tensors[j].copy_(chunk) if rank == opts.rootRank: - output_tensors = output_tensors_list[0] - if is_scalar: - for i in range(xr.world_size()): - output_tensors[i].copy_(gathered_tensor[i]) - else: - chunk_size = input_tensor.shape[0] - gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0) - for i, chunk in enumerate(gathered_chunks): - if chunk.shape != output_tensors[i].shape: - chunk = chunk.reshape(output_tensors[i].shape) - output_tensors[i].copy_(chunk) - - if rank == opts.rootRank: - return _ret_work(output_tensors_list[0]) + return _ret_work(output_tensors_list) else: - return _ret_work([]) + return _ret_work([[]]) # Called by torch.distributed.scatter. Call site example: # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146 From 8173e7162f7c0231727afa59f6b626609f514312 Mon Sep 17 00:00:00 2001 From: bfolie Date: Tue, 1 Jul 2025 21:50:57 +0000 Subject: [PATCH 4/8] add some more comments to test --- test/pjrt/test_collective_ops_tpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 2ae994b79040..80ee4b04b3b4 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -131,13 +131,19 @@ def _gather(scalar: bool = False): dist.init_process_group("xla", init_method='xla://') device = torch_xla.device() world_size = xr.world_size() + + # If scalar, tensors are tensor(i). Otherwise they are tensor([i]). + # The two cases follow different and should be tested separately. 
if scalar: item = xr.global_ordinal() dummy = -1.0 else: item = [xr.global_ordinal()] dummy = [-1.0] + tensor = torch.tensor(item, device=device, dtype=torch.float) + + # Instantiate tensors on device 0 to receive the results output_tensors = None if xr.global_ordinal() == 0: output_tensors = [ From 39f19084c6a4743dee3f473b7c7ab1c9cb57a40e Mon Sep 17 00:00:00 2001 From: bfolie Date: Tue, 1 Jul 2025 21:54:39 +0000 Subject: [PATCH 5/8] format --- test/pjrt/test_collective_ops_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 80ee4b04b3b4..a83ba69bdc05 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -140,7 +140,7 @@ def _gather(scalar: bool = False): else: item = [xr.global_ordinal()] dummy = [-1.0] - + tensor = torch.tensor(item, device=device, dtype=torch.float) # Instantiate tensors on device 0 to receive the results From 1a6c3b65d607aeb60ec551bcb253b917f2df6ca2 Mon Sep 17 00:00:00 2001 From: bfolie Date: Tue, 1 Jul 2025 22:29:53 +0000 Subject: [PATCH 6/8] move scatter and gather tests into more appropriate class --- test/pjrt/test_collective_ops_tpu.py | 152 +++++++++++++-------------- 1 file changed, 76 insertions(+), 76 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index a83ba69bdc05..a959473aa54b 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -103,82 +103,6 @@ def test_reduce_scatter(self, pin_layout): for ordinal, value in results.items(): np.testing.assert_array_equal(value, [-ordinal]) - @staticmethod - def _scatter(): - dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() - world_size = xr.world_size() - tensors = None - if xr.global_ordinal() == 0: - tensors = [ - torch.tensor([i], device=device, dtype=torch.float) - for i in range(world_size) - ] - - output_tensor = torch.tensor([-1], dtype=torch.float, device=device) - dist.scatter(output_tensor, tensors, src=0) - return output_tensor.cpu() - - def test_scatter(self): - """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]] - on device 0, then scatters it. Device i should therefore receive [i].""" - results = pjrt.run_multiprocess(self._scatter) - for ordinal, value in results.items(): - np.testing.assert_array_equal(value, [ordinal]) - - @staticmethod - def _gather(scalar: bool = False): - dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() - world_size = xr.world_size() - - # If scalar, tensors are tensor(i). Otherwise they are tensor([i]). - # The two cases follow different and should be tested separately. - if scalar: - item = xr.global_ordinal() - dummy = -1.0 - else: - item = [xr.global_ordinal()] - dummy = [-1.0] - - tensor = torch.tensor(item, device=device, dtype=torch.float) - - # Instantiate tensors on device 0 to receive the results - output_tensors = None - if xr.global_ordinal() == 0: - output_tensors = [ - torch.tensor(dummy, device=device, dtype=torch.float) - for _ in range(world_size) - ] - - dist.gather(tensor, output_tensors, dst=0) - if not output_tensors: - return None - else: - return [t.cpu() for t in output_tensors] - - @parameterized.named_parameters(('scalar', True), ('tensor', False)) - def test_gather(self, scalar): - # self._gather instantiates tensor i or [i], depending on the value of - # `scalar`, on device i. The results are gathered on device 0. 
- # All other devices get None. - results = pjrt.run_multiprocess(self._gather, scalar) - if scalar: - expected = [ - torch.tensor(i, dtype=torch.float) - for i in range(tpu.num_expected_global_devices()) - ] - else: - expected = [ - torch.tensor([i], dtype=torch.float) - for i in range(tpu.num_expected_global_devices()) - ] - for ordinal, value in results.items(): - if ordinal == 0: - torch.testing.assert_close(value, expected) - else: - assert value is None - @staticmethod def _all_to_all(pin_layout): device = torch_xla.device() @@ -412,6 +336,82 @@ def test_all_to_all_single(self, use_dynamo): expected.sort().values), f"Got {val}, expected {expected}") + @staticmethod + def _scatter(): + dist.init_process_group("xla", init_method='xla://') + device = torch_xla.device() + world_size = xr.world_size() + tensors = None + if xr.global_ordinal() == 0: + tensors = [ + torch.tensor([i], device=device, dtype=torch.float) + for i in range(world_size) + ] + + output_tensor = torch.tensor([-1], dtype=torch.float, device=device) + dist.scatter(output_tensor, tensors, src=0) + return output_tensor.cpu() + + def test_scatter(self): + """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]] + on device 0, then scatters it. Device i should therefore receive [i].""" + results = pjrt.run_multiprocess(self._scatter) + for ordinal, value in results.items(): + np.testing.assert_array_equal(value, [ordinal]) + + @staticmethod + def _gather(scalar: bool = False): + dist.init_process_group("xla", init_method='xla://') + device = torch_xla.device() + world_size = xr.world_size() + + # If scalar, tensors are tensor(i). Otherwise they are tensor([i]). + # The two cases follow different and should be tested separately. + if scalar: + item = xr.global_ordinal() + dummy = -1.0 + else: + item = [xr.global_ordinal()] + dummy = [-1.0] + + tensor = torch.tensor(item, device=device, dtype=torch.float) + + # Instantiate tensors on device 0 to receive the results + output_tensors = None + if xr.global_ordinal() == 0: + output_tensors = [ + torch.tensor(dummy, device=device, dtype=torch.float) + for _ in range(world_size) + ] + + dist.gather(tensor, output_tensors, dst=0) + if not output_tensors: + return None + else: + return [t.cpu() for t in output_tensors] + + @parameterized.named_parameters(('scalar', True), ('tensor', False)) + def test_gather(self, scalar): + # self._gather instantiates tensor i or [i], depending on the value of + # `scalar`, on device i. The results are gathered on device 0. + # All other devices get None. 
+ results = pjrt.run_multiprocess(self._gather, scalar) + if scalar: + expected = [ + torch.tensor(i, dtype=torch.float) + for i in range(tpu.num_expected_global_devices()) + ] + else: + expected = [ + torch.tensor([i], dtype=torch.float) + for i in range(tpu.num_expected_global_devices()) + ] + for ordinal, value in results.items(): + if ordinal == 0: + torch.testing.assert_close(value, expected) + else: + assert value is None + if __name__ == '__main__': absltest.main() From 619d4435ada49e9919e0b79a838999c228035178 Mon Sep 17 00:00:00 2001 From: Brendan Folie Date: Wed, 9 Jul 2025 11:39:58 -0700 Subject: [PATCH 7/8] Add space before comment --- torch_xla/distributed/xla_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 01d6117034b4..52f41edbe340 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -267,6 +267,7 @@ def gather(self, output_tensors_list: list[list[torch.Tensor]], gathered_tensor = xm.all_gather( input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False) + # Syncing is required to keep the heterogeneous copying below at the # Python layer, avoiding deadlocks due to mismatched HLO. torch_xla.sync() From 492e1ee68d04bab0a62c271d4060b1347c43937d Mon Sep 17 00:00:00 2001 From: Brendan Folie Date: Thu, 17 Jul 2025 20:51:01 -0700 Subject: [PATCH 8/8] Update test_collective_ops_tpu.py --- test/pjrt/test_collective_ops_tpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 39987577ea63..1afe859a2040 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -441,6 +441,8 @@ def test_gather(self, scalar): torch.testing.assert_close(value, expected) else: assert value is None + + @staticmethod def _reduce(): dist.init_process_group("xla", init_method='xla://') device = torch_xla.device()
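
For reference, below is a minimal usage sketch of the gather support added by this patch series, written against the public torch.distributed API rather than the backend internals. It mirrors the _gather test above and assumes the processes are launched with an XLA multiprocess runner (for example torch_xla's pjrt.run_multiprocess, as in the tests); the function name gather_example and the printout at the end are illustrative only, not part of the patches.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def gather_example():
  # Initialize the XLA process group, as done in the tests in this series.
  dist.init_process_group("xla", init_method='xla://')
  device = torch_xla.device()
  world_size = xr.world_size()
  rank = xr.global_ordinal()

  # Every rank contributes a one-element tensor holding its ordinal.
  tensor = torch.tensor([rank], device=device, dtype=torch.float)

  # Only the destination rank allocates receive buffers; all other ranks
  # pass None, matching the call convention exercised by the new test.
  gather_list = None
  if rank == 0:
    gather_list = [
        torch.empty(1, device=device, dtype=torch.float)
        for _ in range(world_size)
    ]

  dist.gather(tensor, gather_list, dst=0)

  if rank == 0:
    # gather_list now holds [tensor([0.]), tensor([1.]), ..., tensor([n-1.])].
    print([t.cpu() for t in gather_list])

Note that, as the backend comment in patch 3 explains, non-destination ranks still perform the underlying all_gather to avoid XLA deadlocks; they simply discard the result, so gather on this backend trades some redundant work for correctness.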