
Remove device_count for TPU launcher to avoid initializing runtime #3587

Open
sorgfresser wants to merge 1 commit into main

Conversation

@sorgfresser commented on May 24, 2025

What does this PR do?

An issue similar to the one addressed in PR #3541 remains for the TPU launcher. With this change, we avoid prematurely initializing the XLA runtime.

accelerate launch <script>

currently fails; this change fixes it. Tested on Google Cloud TPU v6e.
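
For context, a minimal sketch of the problem (not the actual accelerate source; the launcher shape, the launch function name, and the top-level torch_xla device_count import are assumptions for illustration): the guard removed by this PR queries the TPU device count in the launcher process itself, which brings up the XLA runtime before xmp.spawn has forked the workers.

# Minimal sketch, assuming a recent torch_xla that exposes device_count() at
# the top level -- NOT the actual accelerate implementation.
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import device_count

def launch(main_fn, num_processes=None):
    # The removed guard: device_count() runs in the parent (launcher) process
    # and initializes the XLA runtime there, before any worker is spawned.
    if num_processes and num_processes != device_count():
        raise ValueError(
            f"Number of processes ({num_processes}) must match the number of TPU devices ({device_count()})"
        )
    # Spawning afterwards can then fail because the runtime is already
    # initialized in the parent process.
    xmp.spawn(main_fn, args=())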

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc @zach-huggingface

@SunMarc (Member) left a comment

Thanks, just a nit

Comment on lines -901 to -904
if args.num_processes and args.num_processes != device_count():
    raise ValueError(
        f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
    )
@SunMarc (Member) commented on May 26, 2025

If you are removing that, maybe we should say somewhere that we will be running the script on all available TPU cores -> print("Launching a training on all TPU cores.")

A contributor commented:

IIRC, XLA now expects no arguments at all, meaning it will use all TPU cores. If that is right, the change seems correct.
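
To illustrate that, a hedged sketch (the worker function is a placeholder, not accelerate's API): leaving nprocs unset lets xmp.spawn decide how many processes to start, and recent torch_xla versions then use every available TPU core.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned worker gets its own ordinal and XLA device.
    print(f"worker {index} using device {xm.xla_device()}")

if __name__ == "__main__":
    # No nprocs argument: xmp.spawn() fans out over all available TPU cores
    # (under PJRT, IIRC only nprocs=None or nprocs=1 is accepted).
    xmp.spawn(_mp_fn, args=())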

@SunMarc requested a review from tengomucho on May 26, 2025 at 14:06

@@ -898,10 +897,6 @@ def tpu_launcher(args):
f"Your training script should have a function named {args.main_training_function}, or you should pass a "
"different value to `--main_training_function`."
)
@sorgfresser (Author) commented:

Suggested change:
- )
+ )
+ logger.warning("Launching training on all TPU cores.")

Like that? @SunMarc
