[compile] Fix graphbreaks in moe split; scale_grad #2771
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2771
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fe5c81e with merge base 86f148b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2771      +/-   ##
==========================================
- Coverage   62.93%   60.10%    -2.83%
==========================================
  Files         437      437
  Lines       26710    26712        +2
==========================================
- Hits        16809    16055      -754
- Misses       9901    10657      +756
Two quick comments but looks good!
if self._compile_model:
    # Capture scalar outputs is required to compile MoE
    torch._dynamo.config.capture_scalar_outputs = True
We may also want to add this to lora_finetune_distributed.py too (I think the logic should be the same there)
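For illustration only, a minimal sketch of what that could look like, assuming lora_finetune_distributed.py exposes the same `self._compile_model` flag; the helper name here is hypothetical, not torchtune code:

```python
import torch


def maybe_enable_scalar_capture(compile_model: bool) -> None:
    # Hypothetical helper mirroring the guard above: MoE experts call
    # torch.split(x, counts.tolist()), and dynamo can only trace .tolist()
    # when scalar outputs are captured.
    if compile_model:
        torch._dynamo.config.capture_scalar_outputs = True
```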
"Using flex attention for attention computation since a BlockMask was passed in.", | ||
level=logging.DEBUG, | ||
) | ||
if not torch.compiler.is_compiling(): |
Noob q: why do we only want to log this when we're not compiling?
Dynamo graph-breaks on log(), so the check is only there to avoid that graph break. It's still safe to log in normal, non-compiled execution :)
Graph break in user code at /data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py:105
Graph Break Reason: Logger not supported for non-export cases. To avoid graph breaks caused by logger in compile-mode, it is recommended to disable logging by adding logging methods to config.ignore_logger_methods
User code traceback:
File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1204, in <module>
sys.exit(recipe_main())
File "/data/users/ivankobzarev/b/torchtune/torchtune/config/_parse.py", line 99, in wrapper
sys.exit(recipe_main(conf))
File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1199, in recipe_main
recipe.train()
File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1034, in train
current_loss = self._loss_step(batch) * current_num_tokens
File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 927, in _loss_step
outputs = self._model(**batch)
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
return inner()
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
result = forward_call(*args, **kwargs)
File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/model_fusion/_early_fusion.py", line 287, in forward
output = self.decoder(
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 661, in forward
h = layer(
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
return inner()
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
result = forward_call(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
return self.checkpoint_fn( # type: ignore[misc]
File "/data/users/ivankobzarev/b/pytorch/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
ret = function(*args, **kwargs)
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1765, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 134, in forward
attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention.py", line 292, in forward
output = self._attention_call(
File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention_utils.py", line 214, in _attention_call
log_once(
File "/data/users/ivankobzarev/b/pytorch/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
return fn(*args[2:], **kwargs)
File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 55, in log_once
log_rank_zero(logger=logger, msg=msg, level=level)
File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 105, in log_rank_zero
logger.log(level, msg, stacklevel=2)
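For reference, a minimal standalone sketch of the guard pattern described above; the logger and message mirror the snippet earlier, but the rest is illustrative rather than the exact torchtune code:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def log_flex_attention_notice() -> None:
    # logger.log() triggers a dynamo graph break, so skip logging inside a
    # compiled region; eager (non-compiled) execution still emits the message.
    if not torch.compiler.is_compiling():
        logger.log(
            logging.DEBUG,
            "Using flex attention for attention computation since a BlockMask was passed in.",
        )
```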
Stack from ghstack (oldest at bottom):
Fixing three graph breaks for compile:
1/ moe/experts does split(tensor.tolist()); .tolist() can only be compiled if torch._dynamo.config.capture_scalar_outputs = True is set.
After this, moe/experts compiles without problems for flex_attn.
2/ model.parameters() produces a generator, which cannot be an input to a dynamo graph; adding a list() call avoids this graph break inside the compiled region (a sketch of 1/ and 2/ follows after this list).
3/ Logging in attention_utils during compilation results in a graph break.
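A minimal, self-contained sketch of fixes 1/ and 2/. It is illustrative only: `route_tokens`, `scale_grads`, and the shapes are placeholders rather than the torchtune implementation, and it assumes a recent PyTorch with torch.compile.

```python
import torch

# Fix 1/: .tolist() yields data-dependent Python ints; dynamo can only capture
# them as symbolic scalars when this flag is enabled.
torch._dynamo.config.capture_scalar_outputs = True


def route_tokens(tokens: torch.Tensor, tokens_per_expert: torch.Tensor) -> list[torch.Tensor]:
    # Placeholder for the MoE expert split: chunk sizes come from a tensor,
    # so .tolist() is needed inside the compiled region.
    return list(torch.split(tokens, tokens_per_expert.tolist(), dim=0))


@torch.compile(backend="eager")  # "eager" backend: only dynamo tracing matters here
def scale_grads(params: list[torch.Tensor], factor: float) -> None:
    # Fix 2/: a generator from model.parameters() cannot be a dynamo graph
    # input, so the caller materializes it with list() first.
    for p in params:
        if p.grad is not None:
            p.grad.mul_(factor)


compiled_route = torch.compile(route_tokens, backend="eager")
chunks = compiled_route(torch.randn(6, 4), torch.tensor([2, 1, 3]))

model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
scale_grads(list(model.parameters()), 0.5)  # list(...) avoids the graph break
```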