Add bf16/fp16 support for amp with mps device #3373
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@SunMarc Hi Marc, I have tried installing this version of accelerate with PyTorch 2.6.0 to use the Trainer on the mps device, but got the following error message. Could you please help me check it out?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 30
20 trainer = Trainer(
21 model=model,
22 args=training_args,
(...)
27 compute_metrics=compute_metrics
28 )
29 logger.info("Start training")
---> 30 trainer.train()
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1883 hf_hub_utils.enable_progress_bars()
1884 else:
-> 1885 return inner_training_loop(
1886 args=args,
1887 resume_from_checkpoint=resume_from_checkpoint,
1888 trial=trial,
1889 ignore_keys_for_eval=ignore_keys_for_eval,
1890 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:2216, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2213 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2215 with self.accelerator.accumulate(model):
-> 2216 tr_loss_step = self.training_step(model, inputs)
2218 if (
2219 args.logging_nan_inf_filter
2220 and not is_torch_xla_available()
2221 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2222 ):
2223 # if loss is nan or inf simply add the average of previous logged losses
2224 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:3250, in Trainer.training_step(***failed resolving arguments***)
3248 scaled_loss.backward()
3249 else:
-> 3250 self.accelerator.backward(loss)
3252 return loss.detach() / self.args.gradient_accumulation_steps
File ~/GitHub/accelerate/src/accelerate/accelerator.py:2250, in Accelerator.backward(self, loss, **kwargs)
2248 self.lomo_backward(loss, learning_rate)
2249 else:
-> 2250 loss.backward(**kwargs)
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/_tensor.py:626, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
616 if has_torch_function_unary(self):
617 return handle_torch_function(
618 Tensor.backward,
619 (self,),
(...)
624 inputs=inputs,
625 )
--> 626 torch.autograd.backward(
627 self, gradient, retain_graph, create_graph, inputs=inputs
628 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
342 retain_graph = create_graph
344 # The reason we repeat the same comment below is that
345 # some Python versions print out the first line of a multi-line function
346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
348 tensors,
349 grad_tensors_,
350 retain_graph,
351 create_graph,
352 inputs,
353 allow_unreachable=True,
354 accumulate_grad=True,
355 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/autograd/graph.py:823, in _engine_run_backward(t_outputs, *args, **kwargs)
821 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
822 try:
--> 823 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
824 t_outputs, *args, **kwargs
825 ) # Calls into the C++ engine to run the backward pass
826 finally:
827 if attach_logging_hooks:
RuntimeError: Expected scalar_type == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int || scalar_type == ScalarType::Bool to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
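For reference, the trace bottoms out in loss.backward() on a loss computed under mps autocast. Below is a minimal sketch of that code path, not the reporter's actual script; whether it reproduces the error depends on the installed PyTorch and accelerate versions:

```python
import torch

# Minimal sketch of the failing path: forward under mps autocast, then backward().
# Purely illustrative; the report above used transformers' Trainer, not this toy model.
device = torch.device("mps")
model = torch.nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)

with torch.autocast(device_type="mps", dtype=torch.bfloat16):  # the reported run may have used fp16 instead
    loss = model(x).sum()

loss.backward()  # the reported RuntimeError surfaces during this call on affected setups
```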
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Any update on this now that GradScaler with autocast support has been merged?
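(For context, "GradScaler with autocast" refers to the usual scaled mixed-precision training pattern, sketched below under the assumption that the installed PyTorch accepts "mps" in torch.amp.GradScaler and torch.autocast:)

```python
import torch

# Sketch of the GradScaler + autocast pattern on mps. Assumes a PyTorch build
# where torch.amp.GradScaler("mps") and mps autocast are available.
device = torch.device("mps")
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("mps")

x = torch.randn(8, 16, device=device)
with torch.autocast(device_type="mps", dtype=torch.float16):
    loss = model(x).sum()

scaler.scale(loss).backward()  # scale the fp16 loss before backward
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adjusts the scale factor for the next iteration
```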
Nice! Would you like to try training with fp16 + mps to see if everything works correctly?
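(A minimal smoke test along those lines, using a toy model and random data as placeholders; it assumes Accelerator(mixed_precision="fp16") picks up the mps device and the scaler path from this PR:)

```python
import torch
from accelerate import Accelerator

# Toy fp16 + mps smoke test; the model and data are placeholders.
accelerator = Accelerator(mixed_precision="fp16")  # assumes mps is auto-detected

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(10):
    x = torch.randn(8, 16, device=accelerator.device)
    loss = model(x).sum()       # forward runs under autocast via prepare()
    accelerator.backward(loss)  # should route through the fp16 GradScaler
    optimizer.step()
    optimizer.zero_grad()

print("fp16 training steps completed on", accelerator.device)
```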
What does this PR do?
This PR adds MPS mixed-precision autocast support.
Draft until we get support for GradScaler with autocast. Right now, support for bf16 ops on mps is still a bit limited, but the PyTorch team is working on improving the coverage.
Feel free to test the PR and try bf16 for now.
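Since bf16 op coverage on mps is still catching up, one quick way to check a given PyTorch build is to run a few common ops under mps autocast and see which ones fail. The sketch below is purely illustrative and the op list is arbitrary:

```python
import torch
import torch.nn.functional as F

# Illustrative probe of bf16 autocast coverage on mps; the op list is arbitrary.
device = torch.device("mps")
ops = {
    "matmul": lambda: torch.randn(8, 8, device=device) @ torch.randn(8, 8, device=device),
    "layer_norm": lambda: F.layer_norm(torch.randn(8, 8, device=device), (8,)),
    "softmax": lambda: torch.softmax(torch.randn(8, 8, device=device), dim=-1),
    "gelu": lambda: F.gelu(torch.randn(8, 8, device=device)),
}

with torch.autocast(device_type="mps", dtype=torch.bfloat16):
    for name, fn in ops.items():
        try:
            out = fn()
            print(f"{name:>10}: ok, output dtype {out.dtype}")
        except (RuntimeError, NotImplementedError) as exc:
            print(f"{name:>10}: failed ({exc})")
```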