diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index a1040680ad..152135fda9 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -342,6 +342,9 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_loss = compile.get("loss", True)
         self._compile_optimizer_step = compile.get("optimizer_step", False)
         self._compile_scale_grads = compile.get("scale_grads", True)
+        if self._compile_model:
+            # Capture scalar outputs is required to compile MoE
+            torch._dynamo.config.capture_scalar_outputs = True
 
         # This indirection is needed to apply torch.compile to scale_grads step.
         self._grad_scaler = training.scale_grads_
@@ -941,7 +944,7 @@ def train(self) -> None:
 
                     # Manually scale the gradients from unnormalized loss by total # of tokens
                     self._grad_scaler(
-                        self._model.parameters(),
+                        list(self._model.parameters()),
                         self.world_size / num_tokens,
                         False if self.parallel_dims.tp_enabled else None,
                     )
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 6bfcd884d7..e8bb616c1e 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -282,6 +282,8 @@ def setup(self, cfg: DictConfig) -> None:
         checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
 
         self._compile = cfg.get("compile", False)
+        # Capture scalar outputs is required to compile MoE
+        torch._dynamo.config.capture_scalar_outputs = True
 
         self._model = self._setup_model(
             cfg_model=cfg.model,
diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py
index c3c2cc7409..435238025f 100644
--- a/torchtune/modules/attention_utils.py
+++ b/torchtune/modules/attention_utils.py
@@ -209,11 +209,12 @@ def _attention_call(
         # This will use flex attention under the hood with support for custom masks.
         # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
         if isinstance(mask, BlockMask):
-            log_once(
-                _log,
-                "Using flex attention for attention computation since a BlockMask was passed in.",
-                level=logging.DEBUG,
-            )
+            if not torch.compiler.is_compiling():
+                log_once(
+                    _log,
+                    "Using flex attention for attention computation since a BlockMask was passed in.",
+                    level=logging.DEBUG,
+                )
             if dropout_p > 0.0:
                 raise ValueError(
                     "Flex attention does not support dropout. Please set dropout to 0.0."
diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py
index 0ec6179213..8b7984c786 100644
--- a/torchtune/modules/moe/experts.py
+++ b/torchtune/modules/moe/experts.py
@@ -50,8 +50,6 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
-    # TODO: remove once compilation is fixed
-    @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
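
Note (not part of the patch): a minimal standalone sketch of why the recipes set torch._dynamo.config.capture_scalar_outputs = True before compiling MoE models. The function name and tensors below are illustrative only, not code from this repo; the flag itself and torch.compile are real PyTorch APIs.

# Without capture_scalar_outputs, Tensor.item() inside a compiled region
# forces a graph break (or an error under fullgraph=True); with it, Dynamo
# captures the scalar as a symbolic value and keeps the graph whole.
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def scale_by_count(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # Data-dependent scalar read-back, similar in spirit to MoE routing code
    # that inspects per-expert token counts.
    total = tokens_per_expert.sum().item()
    return x / total

print(scale_by_count(torch.ones(4), torch.tensor([2, 3])))  # tensor([0.2, 0.2, 0.2, 0.2])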