diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index ff7d3b7a07..d2421eeae9 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -746,6 +746,10 @@ def compile_module(
     Returns:
         Compiled FX GraphModule
     """
+    if any(v.requires_grad for v in gm.state_dict().values()):
+        logger.warning(
+            "The model may be in training mode, which may affect the performance of the compiled model!"
+        )
     dryrun_tracker = DryRunTracker()
     if sample_kwarg_inputs is None:
         sample_kwarg_inputs = {}
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index ba83c89dcd..07ad81d71f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -3,7 +3,7 @@
 import warnings
 from copy import deepcopy
 from enum import Enum, auto
-from typing import Any, Dict, Iterator, Optional, Set, Union
+from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -70,7 +70,9 @@ def __init__(
         strict: bool = True,
         allow_complex_guards_as_runtime_asserts: bool = False,
         weight_streaming_budget: Optional[int] = None,
-        enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
+        enabled_precisions: Union[
+            Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
+        ] = _defaults.ENABLED_PRECISIONS,
         **kwargs: Any,
     ) -> None:
         """
@@ -128,6 +130,10 @@ def __init__(
         self.refit_state = RefitState()
         self.pytorch_model = _make_refit_change_trigger(pytorch_model, self.refit_state)
         self.original_model = pytorch_model
+        if pytorch_model.training:
+            logger.warning(
+                "The model may be in training mode, which may affect the performance of the compiled model!"
+            )
         # Process settings
         self.gm: Any = None
         self.exp_program: Any = None
@@ -163,8 +169,6 @@ def __init__(
                    "Weight stremaing budget is not set. Using auto weight streaming budget"
                )
         self.enabled_precisions = enabled_precisions
-        if self.enabled_precisions is None:
-            self.enabled_precisions = _defaults.ENABLED_PRECISIONS
 
         cls = self.__class__
         self.__class__ = type(
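
Below is a minimal usage sketch of the changed interface, not part of the patch itself. It assumes a CUDA device with TensorRT available and uses a trivial nn.Sequential as a stand-in for a real model: enabled_precisions on MutableTorchTensorRTModule now accepts a tuple as well as a set and defaults to _defaults.ENABLED_PRECISIONS when omitted, and compiling a module left in training mode now logs a warning, so the model is switched to eval() first.

import torch
import torch_tensorrt

# Trivial stand-in model for illustration; any nn.Module works here.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).cuda()
model.eval()  # avoid the training-mode warning introduced by this patch

# A tuple is now accepted in addition to a set; omitting the argument
# falls back to _defaults.ENABLED_PRECISIONS instead of None.
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    enabled_precisions=(torch.float16,),
)
outputs = mutable_module(torch.randn(4, 32, device="cuda"))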