Disable open PJRT for xla2 if OS env disable flag is true #7769

Merged 2 commits on Jul 30, 2024

FanhaiLu1
Contributor

Add disable_xla2_PJRT_test so engineers can turn off the jax PJRT test. By default, it does not change the existing logic.

Background:

In several Ray cases, calling jax devices on both the head and the workers caused a TPU-exhausted exception. There are 2 ways to fix it:

  1. Move the jax calls that xla2 makes at init into a function, don't call it at init time, and call it only after everything else has been initialized
  2. Add a flag that disables the jax function call in __init__.py
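
The second option can be sketched roughly as below. This is only an illustration of the pattern, not the code in this PR: the environment variable name `DISABLE_XLA2_PJRT_TEST` and the helpers `flag_is_true` and `maybe_probe_devices` are hypothetical names chosen for the example, and the device query is passed in as a callable so the sketch runs without jax installed.

```python
import os

# Hypothetical env var name, modeled on the PR's disable_xla2_PJRT_test flag.
DISABLE_FLAG = "DISABLE_XLA2_PJRT_TEST"


def flag_is_true(name, environ=None):
    """Return True if the env var is set to a truthy string."""
    environ = os.environ if environ is None else environ
    return environ.get(name, "").strip().lower() in ("1", "true", "yes")


def maybe_probe_devices(probe, environ=None):
    """Run probe() unless the disable flag is set.

    probe is any zero-argument callable, for example
    `lambda: jax.devices()`. Returns the probe result, or None
    when the probe was skipped.
    """
    if flag_is_true(DISABLE_FLAG, environ):
        # Skip the device query so this process does not claim the TPU.
        return None
    return probe()
```

With the flag unset, the probe runs as before, which matches the "by default, it won't impact existing logic" intent. Taking the probe as a callable also covers the first option in spirit: nothing touches the devices at import time, only when the caller decides to invoke it.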

@FanhaiLu1
Contributor Author

@qihqi Can you take a look and approve this PR?

@qihqi qihqi self-requested a review July 30, 2024 18:11
@qihqi qihqi merged commit fb2d4e1 into pytorch:master Jul 30, 2024
2 checks passed