Commit 5dbf2b9

authored May 24, 2025
add feature gate for tensorrt plugin (#3518)
1 parent af876d5 commit 5dbf2b9

4 files changed: +130 additions, −8 deletions

py/torch_tensorrt/_features.py

Lines changed: 30 additions & 2 deletions
@@ -1,3 +1,4 @@
+import importlib
 import os
 import sys
 from collections import namedtuple
@@ -15,6 +16,7 @@
         "dynamo_frontend",
         "fx_frontend",
         "refit",
+        "qdp_plugin",
     ],
 )
@@ -39,14 +41,24 @@
 _FX_FE_AVAIL = True
 _REFIT_AVAIL = True
 
+if importlib.util.find_spec("tensorrt.plugin"):
+    _QDP_PLUGIN_AVAIL = True
+else:
+    _QDP_PLUGIN_AVAIL = False
+
 ENABLED_FEATURES = FeatureSet(
-    _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
+    _TS_FE_AVAIL,
+    _TORCHTRT_RT_AVAIL,
+    _DYNAMO_FE_AVAIL,
+    _FX_FE_AVAIL,
+    _REFIT_AVAIL,
+    _QDP_PLUGIN_AVAIL,
 )
 
 
 def _enabled_features_str() -> str:
     enabled = lambda x: "ENABLED" if x else "DISABLED"
-    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n"  # type: ignore[no-untyped-call]
+    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n"  # type: ignore[no-untyped-call]
     return out_str
 
 
@@ -64,6 +76,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
     return wrapper
 
 
+def needs_qdp_plugin(f: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+        if ENABLED_FEATURES.qdp_plugin:
+            return f(*args, **kwargs)
+        else:
+
+            def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+                raise NotImplementedError(
+                    "TensorRT QDP (Quick Deploy Plugins) not available, requires TensorRT 10.7.0 or higher"
+                )
+
+            return not_implemented(*args, **kwargs)
+
+    return wrapper
+
+
 def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
     def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
         if ENABLED_FEATURES.refit:
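
With this gate in place, QDP availability can be queried or enforced from downstream code. A minimal sketch of both patterns (the decorated function is hypothetical, for illustration only):

from torch_tensorrt._features import ENABLED_FEATURES, needs_qdp_plugin

# Query the feature flag directly: False when tensorrt.plugin cannot be found
print(ENABLED_FEATURES.qdp_plugin)

# Or gate a function with the decorator: on TensorRT < 10.7.0 the call raises
# NotImplementedError instead of failing at import time
@needs_qdp_plugin
def register_my_plugin() -> None:  # hypothetical example function
    ...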

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 80 additions & 1 deletion
@@ -1,14 +1,18 @@
-from typing import List, Optional, Sequence
+import logging
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
     get_trt_tensor,
     set_layer_name,
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 
+logger = logging.getLogger(__name__)
+
 
 def unsqueeze(
     ctx: ConversionContext,
@@ -18,12 +22,87 @@ def unsqueeze(
     input: TRTTensor,
     dim: int,
 ) -> TRTTensor:
+    from importlib.metadata import version
+
+    if version("tensorrt") < "10.7.0":
+        logger.warning(
+            f"IUnsqueezeLayer is only supported in TensorRT 10.7.0 and newer; falling back to the old unsqueeze implementation on the current TensorRT version: {version('tensorrt')}"
+        )
+        return unsqueeze_old(ctx, target, source_ir, name, input, dim)
     axes = get_trt_tensor(ctx, dim, f"{name}_axes")
     layer = ctx.net.add_unsqueeze(input, axes)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
 
 
+# old implementation, kept for Jetson: IUnsqueezeLayer was not supported prior to TensorRT 10.7.0
+def unsqueeze_old(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    input_val = get_trt_tensor(ctx, input, f"{name}_input")
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"unsqueeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dim = cast(int, dim)
+
+    input_shape_size = len(input_val.shape)
+    dim = get_positive_dim(dim, input_shape_size + 1)
+
+    intermediate_dim = 0
+    dynamic_shape_cnt = 0
+    # if unsqueezing the last dimension, we can directly append to the shape
+    if dim == input_shape_size:
+        intermediate_dim = dim
+    else:
+        # since at most one dimension is permitted to be specified as -1,
+        # find an intermediate_dim with at most one dynamic (-1) dim at or after it,
+        # then add a transpose after the reshape if that is not the final shape we want
+        for i, s in reversed(list(enumerate(input_val.shape))):
+            if i >= dim:
+                if s == -1:
+                    dynamic_shape_cnt += 1
+                if dynamic_shape_cnt > 1:
+                    intermediate_dim = i + 1
+                    break
+                if i == dim:
+                    intermediate_dim = i
+                    break
+    # calculate the new_shape for the shuffle layer's reshape_dims
+    new_shape = list(
+        tuple(input_val.shape)[:intermediate_dim]
+        + (1,)
+        + tuple(input_val.shape)[intermediate_dim:]
+    )
+    for i, s in enumerate(new_shape):
+        if i < intermediate_dim and s == -1:
+            new_shape[i] = 0
+    layer = ctx.net.add_shuffle(input_val)
+    layer.reshape_dims = tuple(new_shape)
+    # if intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after the reshape
+    if intermediate_dim != dim:
+        # calculate the second_transpose for the shuffle layer
+        permutation = [*range(0, len(new_shape))]
+        # for example: if reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5),
+        # then intermediate_dim=3 and dim=1, and we need to move intermediate_dim before [dim: intermediate_dim)
+        new_permutation = (
+            tuple(permutation[:dim])
+            + (intermediate_dim,)
+            + tuple(permutation[dim:intermediate_dim])
+            + tuple(permutation[intermediate_dim + 1 :])
+        )
+        layer.second_transpose = new_permutation
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
 def broadcast_in_dim(
     ctx: ConversionContext,
     target: Target,
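
The second_transpose arithmetic in unsqueeze_old can be checked in isolation. Below is a small standalone sketch (plain Python, no TensorRT required) reproducing the example from the comment above, where dim=1 and intermediate_dim=3:

# permutation computation copied from unsqueeze_old
dim, intermediate_dim = 1, 3
new_shape = [3, 3, 5, 1, 5]  # reshape_dims, with the new 1 at intermediate_dim

permutation = [*range(0, len(new_shape))]         # [0, 1, 2, 3, 4]
new_permutation = (
    tuple(permutation[:dim])                      # (0,)
    + (intermediate_dim,)                         # (3,)
    + tuple(permutation[dim:intermediate_dim])    # (1, 2)
    + tuple(permutation[intermediate_dim + 1 :])  # (4,)
)
assert new_permutation == (0, 3, 1, 2, 4)
assert [new_shape[i] for i in new_permutation] == [3, 1, 3, 5, 5]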

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -3,12 +3,12 @@
 from types import FunctionType
 from typing import Any, Callable, Tuple
 
-import tensorrt.plugin as trtp
 import torch
 from sympy import lambdify
 from torch._dynamo.source import LocalSource
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
+from torch_tensorrt._features import needs_qdp_plugin
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -28,6 +28,13 @@ def mksym(
 
 
 def _generate_plugin(plugin_name: str) -> None:
+    try:
+        import tensorrt.plugin as trtp
+    except ImportError as e:
+        raise RuntimeError(
+            "Unable to import TensorRT plugin. TensorRT version must be 10.7.0 or higher to support Triton-based TensorRT plugins"
+        ) from e
+
     namespace, name = plugin_name.split("::")
 
     # retrieve the corresponding torch operation using the passed in string
@@ -211,6 +218,7 @@ def _generic_plugin_impl(
     trtp.impl(plugin_name)(plugin_impl)
 
 
+@needs_qdp_plugin
 def generate_plugin(plugin_name: str) -> None:
     """
     Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs.
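
With the decorator in place, a call on an older TensorRT fails fast with NotImplementedError rather than with an ImportError at module load. A hypothetical call site (the op name my_ops::elementwise_add is illustrative and assumes the corresponding torch custom op is already registered):

from torch_tensorrt.dynamo.conversion.plugins import generate_plugin

# On TensorRT >= 10.7.0 this generates a QDP plugin for the torch op;
# on older versions @needs_qdp_plugin raises NotImplementedError instead.
generate_plugin("my_ops::elementwise_add")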

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

Lines changed: 11 additions & 4 deletions
@@ -4,12 +4,9 @@
 
 import numpy as np
 import tensorrt as trt
-
-# Seems like a bug in TensorRT
-import tensorrt.plugin as trtp
 import torch
-from tensorrt.plugin._lib import QDP_REGISTRY
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt._features import needs_qdp_plugin
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -32,6 +29,15 @@ def _generate_plugin_converter(
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
 ) -> DynamoConverterImplSignature:
+    try:
+        import tensorrt.plugin as trtp
+
+    except ImportError as e:
+        raise RuntimeError(
+            "Unable to import TensorRT plugin. TensorRT version must be 10.7.0 or higher to support Triton-based TensorRT plugins"
+        ) from e
+    from tensorrt.plugin._lib import QDP_REGISTRY
+
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
     overload_name = overload_str if overload else "default"
@@ -101,6 +107,7 @@ def custom_kernel_converter(
     return custom_kernel_converter
 
 
+@needs_qdp_plugin
 def generate_plugin_converter(
     plugin_id: str,
     capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
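
The converter side is gated the same way. A hypothetical companion call to the example above (same illustrative op; assumes the public wrapper forwards supports_dynamic_shapes to _generate_plugin_converter, as the private signature above suggests):

from torch_tensorrt.dynamo.conversion.plugins import generate_plugin_converter

# Registers a Dynamo converter that lowers torch.ops.my_ops.elementwise_add
# to the QDP plugin; raises NotImplementedError on TensorRT < 10.7.0.
generate_plugin_converter("my_ops::elementwise_add", supports_dynamic_shapes=True)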
