
Commit bd6ea3f

[compile] Fix graphbreaks in moe split; scale_grad

ghstack-source-id: a4efe51
Pull Request resolved: #2771

1 parent: 86f148b

1 file changed (+4, -1)
recipes/full_finetune_distributed.py

Lines changed: 4 additions & 1 deletion
@@ -342,6 +342,9 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_loss = compile.get("loss", True)
         self._compile_optimizer_step = compile.get("optimizer_step", False)
         self._compile_scale_grads = compile.get("scale_grads", True)
+        if self._compile_model:
+            # Enable capture_scalar_outputs to compile non-grouped-mm path of moe, that uses split
+            torch._dynamo.config.capture_scalar_outputs = True
 
         # This indirection is needed to apply torch.compile to scale_grads step.
         self._grad_scaler = training.scale_grads_
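
For context, a minimal sketch (not from this commit; route_tokens and tokens_per_expert are hypothetical names) of the graph break this flag removes: the non-grouped-mm MoE path reads per-expert token counts out of a tensor and feeds them to torch.split, which produces data-dependent scalar outputs under torch.compile.

import torch

# Hypothetical stand-in for the non-grouped-mm MoE dispatch: split routed
# tokens into per-expert chunks, where the chunk sizes live in a tensor.
def route_tokens(tokens: torch.Tensor, tokens_per_expert: torch.Tensor):
    # .tolist() materializes scalar values from a tensor; by default dynamo
    # graph-breaks here instead of capturing them as symbolic ints.
    split_sizes = tokens_per_expert.tolist()
    return torch.split(tokens, split_sizes, dim=0)

# With the flag on, .item()/.tolist() results can be captured as unbacked
# symbolic ints so the split stays inside the compiled graph (exact behavior
# depends on the PyTorch version).
torch._dynamo.config.capture_scalar_outputs = True

compiled = torch.compile(route_tokens)
chunks = compiled(torch.randn(6, 4), torch.tensor([2, 1, 3]))
print([c.shape for c in chunks])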
@@ -941,7 +944,7 @@ def train(self) -> None:
 
                 # Manually scale the gradients from unnormalized loss by total # of tokens
                 self._grad_scaler(
-                    self._model.parameters(),
+                    list(self._model.parameters()),
                     self.world_size / num_tokens,
                     False if self.parallel_dims.tp_enabled else None,
                 )
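
The second change matters because nn.Module.parameters() returns a generator; once the scale_grads step is wrapped in torch.compile, materializing the parameters into a list gives dynamo a concrete sequence to trace rather than a live generator that forces a graph break. A simplified sketch below, where scale_grads_ is an assumed stand-in for training.scale_grads_ with a reduced signature:

import torch
import torch.nn as nn

# Assumed, simplified stand-in for training.scale_grads_: in-place gradient
# scaling compiled as its own small region.
@torch.compile
def scale_grads_(params: list, scale: float) -> None:
    for p in params:
        if p.grad is not None:
            p.grad.mul_(scale)

model = nn.Linear(8, 8)
model(torch.randn(2, 8)).sum().backward()

# Passing list(...) rather than the parameters() generator keeps the compiled
# region intact.
scale_grads_(list(model.parameters()), 0.5)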
