diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index f3e1b3e1fa..dfdc9e1c69 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -21,6 +21,7 @@ repair_input_aliasing, ) from torch_tensorrt.dynamo.utils import ( + is_tegra_platform, parse_dynamo_kwargs, prepare_inputs, set_log_level, @@ -80,10 +81,14 @@ def aot_torch_tensorrt_aten_backend( fw_compiler=_pretraced_backend_autograd, decompositions=settings_aot_autograd["decompositions"], )(gm, sample_inputs) - if any(isinstance(tensor, DTensor) for tensor in sample_inputs): - logger.warning( - "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple" - ) + + if is_tegra_platform(): + from torch.distributed.tensor import DTensor + + if any(isinstance(tensor, DTensor) for tensor in sample_inputs): + logger.warning( + "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple" + ) return _pretraced_backend(gm, sample_inputs, settings, engine_cache) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index fa5eacf7c7..141b68f3e7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -2,7 +2,7 @@ import numpy as np from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.fx.types import TRTNetwork +from torch_tensorrt.dynamo.types import TRTNetwork @dataclass