Commit 35437b1

feat: enable AOT tensorrt plugin example
1 parent 7ea3638

2 files changed: +158 -1 lines

examples/dynamo/aot_plugin.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import argparse
from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

trt_logger = trt.Logger(trt.Logger.VERBOSE)


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
def add_one(X: torch.Tensor) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda

    # Create output tensor
    Y = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 256

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    return X


@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return X.like()


@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,
)


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        res = torch.ops.my.add_one.default(X)

        return res


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--aot", action="store_true", help="Try to use AOT compilation", default=False
    )
    args = parser.parse_args()

    my_model = MyModel().to("cuda")
    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

    assert my_model(X=m)[0][0] == 3.0

    with torch_tensorrt.logging.debug():
        trt_inputs = [m]
        model_trt = torch_tensorrt.compile(
            my_model,
            inputs=trt_inputs,
            debug=True,
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        for i in range(10):
            res = model_trt(m)

            assert torch.allclose(res, my_model(m)), "Results do not match!"

    print("Inference successful!")

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

Lines changed: 14 additions & 1 deletion
@@ -28,6 +28,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    use_aot_if_available: bool = True,
 ) -> DynamoConverterImplSignature:
     try:
         import tensorrt.plugin as trtp
@@ -47,6 +48,16 @@ def _generate_plugin_converter(
     ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter"
     torch_schema = torch_target._schemas[overload_str]

+    use_aot_plugin = use_aot_if_available
+
+    if use_aot_if_available:
+        desc = QDP_REGISTRY[f"{namespace}::{op_name}"]
+        if desc.aot_impl_func is None:
+            use_aot_plugin = False
+            _LOGGER.debug(
+                f"AOT impl func not found for {namespace}::{op_name}, use JIT plugin instead"
+            )
+
     def custom_kernel_converter(
         ctx: ConversionContext,
         target: Target,
@@ -86,7 +97,7 @@ def custom_kernel_converter(
             if isinstance(v, torch.fx.immutable_collections.immutable_list):
                 kwargs[k] = np.array(v)

-        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
+        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=use_aot_plugin)
         assert layer, f"{namespace}::{name} plugin layer was not able to be created"
         _LOGGER.debug(
             f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -114,6 +125,7 @@ def generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    use_aot_if_available: bool = True,
 ) -> DynamoConverterImplSignature:
     plugin_ns, plugin_name = plugin_id.split("::")
     return _generate_plugin_converter(
@@ -123,4 +135,5 @@ def generate_plugin_converter(
         priority=priority,
         supports_dynamic_shapes=supports_dynamic_shapes,
         requires_output_allocator=requires_output_allocator,
+        use_aot_if_available=use_aot_if_available,
     )
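For context, a short sketch (not part of the diff) of how the new use_aot_if_available keyword is exercised from user code. It mirrors the generate_plugin_converter call in examples/dynamo/aot_plugin.py but opts out of the AOT kernel, which is also the path the converter falls back to automatically when no aot_impl is registered for the op.

# Sketch only: assumes the my::add_one registrations from the example above
# (torch custom op, trtp.register, trtp.aot_impl) have already been executed.
import torch_tensorrt

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=False,  # force the JIT plugin path; True is the default
)

With use_aot_if_available=False (or when the plugin has no AOT implementation), the generated converter calls ctx.net.add_plugin(..., aot=False), so the Triton kernel is JIT-compiled rather than embedded as precompiled PTX.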
