Remove device_count for TPU launcher to avoid initializing runtime #3587
Conversation
Thanks, just a nit
    if args.num_processes and args.num_processes != device_count():
        raise ValueError(
            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
        )
If you are removing that, maybe we should note somewhere that the script will run on all available TPU cores -> print("Launching a training on all TPU cores.")
IIRC, now XLA expects no arguments at all, meaning it will use all TPU cores. If that is correct, the change seems correct.
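For illustration, a minimal sketch of that behavior, assuming a recent torch_xla where `xmp.spawn` is called without an explicit process count and fans one replica out to every visible TPU core (the `training_function` name is just a placeholder):

    # Minimal sketch, not the accelerate source: with recent torch_xla,
    # xmp.spawn takes no process count and starts one replica per TPU core.
    import torch_xla.distributed.xla_multiprocessing as xmp

    def training_function(index):
        # `index` is the ordinal of the TPU core running this replica.
        print(f"Hello from TPU core {index}")

    if __name__ == "__main__":
        # No nprocs argument: the runtime decides, i.e. all available cores.
        xmp.spawn(training_function, args=())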
@@ -898,10 +897,6 @@ def tpu_launcher(args):
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )
Suggested change:
        )
        logger.warning("Launching training on all TPU cores.")
Like that? @SunMarc
What does this PR do?
A similar issue to PR #3541 remains for the TPU launcher. With this, we avoid prematurely initializing the XLA runtime.
Launching a training with the TPU launcher currently fails; this PR fixes it. Tested on gcloud v6e TPUs.
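To make the rationale concrete, here is a hedged sketch of the launcher-side pattern (names are illustrative, not the actual accelerate source): calling `device_count()` in the launcher process already initializes the XLA runtime there, before `xmp.spawn` has started the worker processes, which is what this PR avoids.

    # Hedged sketch of the change; `tpu_launch` is a placeholder, not accelerate's API.
    import torch_xla.distributed.xla_multiprocessing as xmp

    def tpu_launch(main_fn):
        # Before: the launcher compared args.num_processes against device_count(),
        # and that device_count() call already initialized the XLA runtime in
        # this parent process.
        # After: no check here; xmp.spawn decides the process count itself and
        # runs the training function on all available TPU cores.
        xmp.spawn(main_fn, args=())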
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@SunMarc @zach-huggingface