diff --git a/recipes/configs/llama4/scout_17B_16E_full.yaml b/recipes/configs/llama4/scout_17B_16E_full.yaml
index 9b2e923f71..855834241e 100644
--- a/recipes/configs/llama4/scout_17B_16E_full.yaml
+++ b/recipes/configs/llama4/scout_17B_16E_full.yaml
@@ -69,7 +69,14 @@ device: cuda
 enable_activation_checkpointing: True
 enable_activation_offloading: False
 fsdp_cpu_offload: True
-compile: False # torch.compile, set to true for perf/memory improvement
+# compile True means use torch.compile for all components
+# compile False means no torch.compile
+# compile Dictionary with keys: "model", "loss", "optimizer_step"
+# enables torch.compile only for specified components.
+compile: False
+# model: True
+# loss: True
+# optimizer_step: False

 # Reduced precision
 dtype: bf16
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 58b29eb4be..435afede84 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 import sys
 import time

@@ -306,7 +307,18 @@ def setup(self, cfg: DictConfig) -> None:
         # Load the base model
         checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

-        self._compile = cfg.get("compile", False)
+        compile = cfg.get("compile")
+        compile_bool = bool(compile)
+        self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        self._compile_model = compile_bool
+        self._compile_loss = compile_bool
+        self._compile_optimizer_step = compile_bool
+        if isinstance(compile, DictConfig):
+            self._compile_model = compile.get("model", True)
+            self._compile_loss = compile.get("loss", True)
+            self._compile_optimizer_step = compile.get("optimizer_step", False)
+
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -329,6 +341,11 @@ def setup(self, cfg: DictConfig) -> None:
                 else None
             ),
         )
+        if self._compile_optimizer_step:
+            self._optimizer.step = torch.compile(
+                self._optimizer.step,
+                backend=self._compile_backend,
+            )

         if self._resume_from_checkpoint:
             # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
@@ -358,7 +375,7 @@ def setup(self, cfg: DictConfig) -> None:
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)

-        if self._compile:
+        if self._compile_loss:
             training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

         # The loss may handle the output projection. If true, the model should skip it.
@@ -569,7 +586,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

-        if self._compile:
+        if self._compile_model:
             training.compile_model(model, verbose=self._is_rank_zero)

         if self._enable_fp8_training:
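
For illustration, assuming the per-component form described in the commented YAML above, a recipe config that compiles only the model and the loss while keeping the optimizer step eager might look like this (the values shown are illustrative, not the defaults):

    # Illustrative override of the compile key: per-component torch.compile.
    compile:
      model: True
      loss: True
      optimizer_step: False

Per the recipe change, the backend used for the compiled components defaults to "inductor" and can be switched through the TORCH_COMPILE_BACKEND environment variable read in setup().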