
Commit 7065eb9

[compile] Fix graph breaks in MoE split; scale_grad
ghstack-source-id: 287e263
Pull Request resolved: #2771
1 parent: 86f148b

File tree

3 files changed: +10 additions, -8 deletions

recipes/full_finetune_distributed.py

Lines changed: 4 additions & 1 deletion
@@ -342,6 +342,9 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_loss = compile.get("loss", True)
         self._compile_optimizer_step = compile.get("optimizer_step", False)
         self._compile_scale_grads = compile.get("scale_grads", True)
+        if self._compile_model:
+            # Capture scalar outputs is required to compile MoE
+            torch._dynamo.config.capture_scalar_outputs = True
 
         # This indirection is needed to apply torch.compile to scale_grads step.
         self._grad_scaler = training.scale_grads_
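
The diff's comment notes that capturing scalar outputs is required to compile MoE: the expert split presumably pulls per-expert token counts out of tensors as Python scalars (e.g. via .item()), which graph-breaks under Dynamo's default config. A minimal, self-contained sketch of what the flag changes (the function and shapes below are illustrative, not torchtune's MoE code):

import torch

# Illustrative standalone example, not torchtune code: a function that pulls a
# Python scalar out of a tensor, the way an MoE token split does when it needs
# per-expert counts on the host.
def scale_by_count(x: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    n = counts.sum().item()  # scalar extraction: a graph break under default Dynamo config
    return x * n

# With the flag off, torch.compile(..., fullgraph=True) errors out on the .item()
# call; with it on, the scalar is traced symbolically and the graph stays in one piece.
torch._dynamo.config.capture_scalar_outputs = True

compiled = torch.compile(scale_by_count, fullgraph=True)
print(compiled(torch.ones(4), torch.tensor([2, 3])))  # tensor([5., 5., 5., 5.])
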
@@ -941,7 +944,7 @@ def train(self) -> None:
 
                 # Manually scale the gradients from unnormalized loss by total # of tokens
                 self._grad_scaler(
-                    self._model.parameters(),
+                    list(self._model.parameters()),
                     self.world_size / num_tokens,
                     False if self.parallel_dims.tp_enabled else None,
                 )
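
Module.parameters() returns a generator, and handing that generator straight to the compiled scale-grads step was one of the graph breaks this commit removes; materializing it with list(...) keeps the compiled region intact. A rough sketch of the call-site pattern, using a deliberately simplified stand-in for training.scale_grads_:

import torch
import torch.nn as nn

# Hedged sketch: a simplified stand-in for the compiled grad-scaling step
# (torchtune's training.scale_grads_ does more than this; only the calling
# convention matters here).
def scale_grads_(params, scaler):
    for p in params:
        if p.grad is not None:
            p.grad.mul_(scaler)

scale_grads_compiled = torch.compile(scale_grads_)

model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# Passing the raw parameters() generator into the compiled region graph-broke
# in the recipe, so the call site builds the list first; 0.5 stands in for the
# world_size / num_tokens factor used above.
scale_grads_compiled(list(model.parameters()), 0.5)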

torchtune/modules/attention_utils.py

Lines changed: 6 additions & 5 deletions
@@ -209,11 +209,12 @@ def _attention_call(
         # This will use flash attention under the hood with support for custom masks.
         # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
         if isinstance(mask, BlockMask):
-            log_once(
-                _log,
-                "Using flex attention for attention computation since a BlockMask was passed in.",
-                level=logging.DEBUG,
-            )
+            if not torch.compiler.is_compiling():
+                log_once(
+                    _log,
+                    "Using flex attention for attention computation since a BlockMask was passed in.",
+                    level=logging.DEBUG,
+                )
         if dropout_p > 0.0:
             raise ValueError(
                 "Flex attention does not support dropout. Please set dropout to 0.0."

torchtune/modules/moe/experts.py

Lines changed: 0 additions & 2 deletions
@@ -50,8 +50,6 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
-    # TODO: remove once compilation is fixed
-    @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
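
With capture_scalar_outputs now set in the recipe, the experts' forward no longer needs the @torch._dynamo.disable escape hatch, which had forced a graph break around every expert call. A toy sketch of what removing the decorator buys (illustrative module, not the torchtune experts implementation):

import torch
import torch.nn as nn

# Toy module, not the torchtune experts class. With
# @torch._dynamo.disable(recursive=False) on forward, Dynamo would run this
# method in eager mode (a graph break at every call); without it, forward is
# traced together with the rest of the compiled model.
class TinyExperts(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.w

compiled = torch.compile(TinyExperts(), fullgraph=True)
print(compiled(torch.randn(2, 4)).shape)  # torch.Size([2, 4])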
