
[compile] Fix graphbreaks in moe split; scale_grad #2771


Merged: 5 commits into main, May 30, 2025

Conversation

@IvanKobzarev (Contributor) commented May 28, 2025

Stack from ghstack (oldest at bottom):

Fixing three graph breaks hit when compiling (see the sketch below for the first two):

1/ moe/experts call split(tensor.tolist()); .tolist() can only be compiled if torch._dynamo.config.capture_scalar_outputs = True is set. With this flag, moe/experts compile without problems with flex_attn.

2/ model.parameters() produces a generator, which cannot be an input to a dynamo graph; adding a list() call avoids this graph break inside the compiled region.

3/ Logging in attention_utils during compilation results in a graph break.
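
A minimal, self-contained sketch of the first two fixes. The route_tokens and scale_grads functions here are illustrative stand-ins, not the actual torchtune code:

import torch
import torch.nn as nn

# Fix 1: .tolist() produces data-dependent Python scalars; without
# capture_scalar_outputs = True, Dynamo graph-breaks on it.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def route_tokens(x: torch.Tensor, tokens_per_expert: torch.Tensor):
    # Same pattern as the MoE experts: split tokens by per-expert counts.
    splits = torch.split(x, tokens_per_expert.tolist(), dim=0)
    return [s.sum(dim=0) for s in splits]

print(route_tokens(torch.randn(6, 4), torch.tensor([2, 1, 3])))

# Fix 2: model.parameters() returns a generator, which cannot be an input to a
# dynamo graph; materialize it with list() before entering the compiled region.
model = nn.Linear(4, 4)

@torch.compile
def scale_grads(params: list, scaler: torch.Tensor):
    for p in params:
        if p.grad is not None:
            p.grad.mul_(scaler)

model(torch.randn(2, 4)).sum().backward()
scale_grads(list(model.parameters()), torch.tensor(0.5))  # list(), not the raw generator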


pytorch-bot bot commented May 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2771

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fe5c81e with merge base 86f148b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit that referenced this pull request May 28, 2025
ghstack-source-id: a4efe51
Pull Request resolved: #2771
@facebook-github-bot added the CLA Signed label May 28, 2025
IvanKobzarev added a commit that referenced this pull request May 28, 2025
ghstack-source-id: fa538d8
Pull Request resolved: #2771
@IvanKobzarev changed the base branch from gh/IvanKobzarev/8/base to main on May 28, 2025 20:06
@codecov-commenter commented May 28, 2025

Codecov Report

Attention: Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.

Project coverage is 60.10%. Comparing base (86f148b) to head (454868d).

Files with missing lines                Patch %   Lines
recipes/full_finetune_distributed.py      0.00%   2 Missing ⚠️
torchtune/modules/attention_utils.py      0.00%   2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2771      +/-   ##
==========================================
- Coverage   62.93%   60.10%   -2.83%     
==========================================
  Files         437      437              
  Lines       26710    26712       +2     
==========================================
- Hits        16809    16055     -754     
- Misses       9901    10657     +756     

☔ View full report in Codecov by Sentry.

IvanKobzarev added a commit that referenced this pull request May 29, 2025
ghstack-source-id: d8697b6
Pull Request resolved: #2771
IvanKobzarev added a commit that referenced this pull request May 29, 2025
ghstack-source-id: 287e263
Pull Request resolved: #2771
@ebsmothers (Contributor) left a comment


Two quick comments but looks good!

Comment on lines +345 to +347
if self._compile_model:
    # Capture scalar outputs is required to compile MoE
    torch._dynamo.config.capture_scalar_outputs = True
Contributor

We may also want to add this to lora_finetune_distributed.py too (I think the logic should be the same there)

"Using flex attention for attention computation since a BlockMask was passed in.",
level=logging.DEBUG,
)
if not torch.compiler.is_compiling():
Contributor

Noob q: why do we only want to log this when we're not compiling?

Contributor Author

Dynamo graph_breaks on log(), so this is only to avoid the graph break. But it's safe to log in normal non-compiling execution :)

Graph break in user code at /data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py:105
Graph Break Reason: Logger not supported for non-export cases. To avoid graph breaks caused by logger in compile-mode, it is recommended to disable logging by adding logging methods to config.ignore_logger_methods
User code traceback:
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1204, in <module>
    sys.exit(recipe_main())
  File "/data/users/ivankobzarev/b/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1199, in recipe_main
    recipe.train()
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1034, in train
    current_loss = self._loss_step(batch) * current_num_tokens
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 927, in _loss_step
    outputs = self._model(**batch)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
    return inner()
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/model_fusion/_early_fusion.py", line 287, in forward
    output = self.decoder(
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 661, in forward
    h = layer(
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
    return inner()
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/data/users/ivankobzarev/b/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1765, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 134, in forward
    attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention.py", line 292, in forward
    output = self._attention_call(
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention_utils.py", line 214, in _attention_call
    log_once(
  File "/data/users/ivankobzarev/b/pytorch/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 55, in log_once
    log_rank_zero(logger=logger, msg=msg, level=level)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 105, in log_rank_zero
    logger.log(level, msg, stacklevel=2)
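
For reference, a minimal sketch of the guard pattern (the wrapper name and logger setup are illustrative, not the actual attention_utils code): the logger call runs only when Dynamo is not tracing, so eager execution still logs while compiled runs avoid the graph break.

import logging

import torch

logger = logging.getLogger(__name__)

def attention_with_log(q, k, v):
    # Dynamo graph-breaks on logger calls, so skip logging while being traced
    # by torch.compile; normal eager execution still emits the message.
    if not torch.compiler.is_compiling():
        logger.debug("Using flex attention for attention computation since a BlockMask was passed in.")
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

compiled = torch.compile(attention_with_log)
q = k = v = torch.randn(1, 2, 8, 16)
out = compiled(q, k, v)  # no graph break from the logger during tracing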

IvanKobzarev added a commit that referenced this pull request May 30, 2025
ghstack-source-id: 3f660b4
Pull Request resolved: #2771
@IvanKobzarev merged commit 5ecae86 into main May 30, 2025
14 checks passed