Skip to content

Implement collective gather op #9435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open

Conversation

bfolie
Copy link
Collaborator

@bfolie bfolie commented Jul 1, 2025

@bfolie bfolie requested a review from pgmoka July 1, 2025 23:08
Comment on lines 106 to 128
@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moving this test into the appropriate class

Comment on lines 339 to 360
@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copied from above

@bfolie
Copy link
Collaborator Author

bfolie commented Jul 2, 2025

Failing tests are expected until the TPU CI cluster is updated to use python 3.12. See #9434

@bfolie bfolie requested a review from benawilson July 2, 2025 19:31
@bfolie bfolie enabled auto-merge (squash) July 9, 2025 20:06
@bfolie bfolie disabled auto-merge July 9, 2025 20:06
@bfolie bfolie requested a review from pgmoka July 11, 2025 18:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants