feat: TensorRT AOT Plugin #3504

Open · wants to merge 3 commits into base: main
Changes from 1 commit
100 changes: 44 additions & 56 deletions examples/dynamo/aot_plugin.py
@@ -1,15 +1,13 @@
import argparse
from typing import Tuple, Union


import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


trt_logger = trt.Logger(trt.Logger.VERBOSE)


@@ -25,9 +23,7 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):


@torch.library.custom_op("my::add_one", mutates_args=()) # type: ignore[misc]
def add_one(
X: torch.Tensor
) -> torch.Tensor:
def add_one(X: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda

@@ -51,63 +47,58 @@ def _(X: torch.Tensor) -> torch.Tensor:
return X


# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
# "my::add_one"
# )

@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
return X.like()

@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:


type_str = "fp32" if X.dtype == trt.float32 else "fp16"

block_size = 256
src = triton.compiler.ASTSource(
fn=add_one_kernel,
signature={
"x_ptr": f"*{type_str}",
"n_elements": "i32",
"y_ptr": f"*{type_str}",
"BLOCK_SIZE": "constexpr",
},
constants={
"BLOCK_SIZE": block_size,
},
)

compiled_kernel = triton.compile(src)

N = X.shape_expr.numel()
launch_params = trtp.KernelLaunchParams()

# grid dims
launch_params.grid_x = trtp.cdiv(N, block_size)
# block dims
launch_params.block_x = compiled_kernel.metadata.num_warps * 32
# shared memory
launch_params.shared_mem = compiled_kernel.metadata.shared

extra_args = trtp.SymIntExprs(1)
extra_args[0] = trtp.SymInt32(N)

return (
compiled_kernel.metadata.name,
compiled_kernel.asm["ptx"],
launch_params,
extra_args,
)
# @trtp.aot_impl("my::add_one")
# def add_plugin_aot_impl(
# X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
# ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
# type_str = "fp32" if X.dtype == trt.float32 else "fp16"

# block_size = 256
# src = triton.compiler.ASTSource(
# fn=add_one_kernel,
# signature={
# "x_ptr": f"*{type_str}",
# "n_elements": "i32",
# "y_ptr": f"*{type_str}",
# "BLOCK_SIZE": "constexpr",
# },
# constants={
# "BLOCK_SIZE": block_size,
# },
# )

# compiled_kernel = triton.compile(src)

# N = X.shape_expr.numel()
# launch_params = trtp.KernelLaunchParams()

# # grid dims
# launch_params.grid_x = trtp.cdiv(N, block_size)
# # block dims
# launch_params.block_x = compiled_kernel.metadata.num_warps * 32
# # shared memory
# launch_params.shared_mem = compiled_kernel.metadata.shared

# extra_args = trtp.SymIntExprs(1)
# extra_args[0] = trtp.SymInt32(N)

# return (
# compiled_kernel.metadata.name,
# compiled_kernel.asm["ptx"],
# launch_params,
# extra_args,
# )

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"my::add_one",
supports_dynamic_shapes=False,
requires_output_allocator=False,
aot=True,
use_aot_if_available=True,
)


@@ -129,15 +120,12 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
)
args = parser.parse_args()



my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

# This works!
assert my_model(X=m)[0][0] == 3.0


with torch_tensorrt.logging.debug():
trt_inputs = [m]
model_trt = torch_tensorrt.compile(
@@ -153,4 +141,4 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
assert torch.allclose(res, my_model(m)), "Results do not match!"

print("Inference successful!")
print(res)
print(res)
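
For reference, a minimal sketch (not part of the diff) of the fallback path this updated example now exercises: with the @trtp.aot_impl registration commented out, passing use_aot_if_available=True should make the generated converter fall back to the JIT plugin, per the converter change below. The sketch assumes the my::add_one registrations and the MyModel module defined above.

import torch
import torch_tensorrt

# Assumes my::add_one is registered via @torch.library.custom_op and
# @trtp.register as in the example above, with no @trtp.aot_impl provided.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,  # no AOT impl is registered, so the JIT plugin is used
)

model = MyModel().to("cuda")  # MyModel wraps torch.ops.my.add_one, as defined above
x = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

trt_model = torch_tensorrt.compile(model, inputs=[x], min_block_size=1)
assert torch.allclose(trt_model(x), model(x))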
@@ -31,7 +31,7 @@ def _generate_plugin_converter(
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
requires_output_allocator: bool = False,
aot: bool = False,
use_aot_if_available: bool = False,
Collaborator review comment on use_aot_if_available: Default to true
) -> DynamoConverterImplSignature:
torch_target = getattr(getattr(torch.ops, namespace), op_name)
overload_str = overload if overload else ""
@@ -42,6 +42,16 @@ def _generate_plugin_converter(
), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter"
torch_schema = torch_target._schemas[overload_str]

use_aot_plugin = use_aot_if_available

if use_aot_if_available:
desc = QDP_REGISTRY[f"{namespace}::{op_name}"]
if desc.aot_impl_func is None:
use_aot_plugin = False
_LOGGER.debug(
f"AOT impl func not found for {namespace}::{op_name}, use JIT plugin instead"
)

def custom_kernel_converter(
ctx: ConversionContext,
target: Target,
@@ -81,7 +91,7 @@ def custom_kernel_converter(
if isinstance(v, torch.fx.immutable_collections.immutable_list):
kwargs[k] = np.array(v)

layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=use_aot_plugin)
assert layer, f"{namespace}::{name} plugin layer was not able to be created"
_LOGGER.debug(
f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -108,7 +118,7 @@ def generate_plugin_converter(
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
requires_output_allocator: bool = False,
aot: bool = False,
use_aot_if_available: bool = False,
) -> DynamoConverterImplSignature:
plugin_ns, plugin_name = plugin_id.split("::")
return _generate_plugin_converter(
@@ -118,5 +128,5 @@
priority=priority,
supports_dynamic_shapes=supports_dynamic_shapes,
requires_output_allocator=requires_output_allocator,
aot=aot,
use_aot_if_available=use_aot_if_available,
)
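
For callers of generate_plugin_converter that previously passed aot=True, a brief hedged sketch of the renamed flag in use (the plugin id my::add_one is borrowed from the example above):

from torch_tensorrt.dynamo.conversion import plugins

# Before this change: plugins.generate_plugin_converter("my::add_one", aot=True)
# With use_aot_if_available, the flag expresses a preference: if no @trtp.aot_impl
# is registered for the op, the converter falls back to the JIT plugin and logs a
# debug message.
plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,
)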