diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py
new file mode 100644
index 0000000000..7e8204c165
--- /dev/null
+++ b/examples/dynamo/aot_plugin.py
@@ -0,0 +1,144 @@
+import argparse
+from typing import Tuple, Union
+
+import tensorrt as trt
+import tensorrt.plugin as trtp
+import torch
+import torch_tensorrt
+import triton
+import triton.language as tl
+
+trt_logger = trt.Logger(trt.Logger.VERBOSE)
+
+
+@triton.jit
+def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    output = x + 1
+    tl.store(y_ptr + offsets, output, mask=mask)
+
+
+@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
+def add_one(X: torch.Tensor) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda
+
+    # Create output tensor
+    Y = torch.empty_like(X)
+
+    # Define block size
+    BLOCK_SIZE = 256
+
+    # Grid of programs
+    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+
+    # Launch the kernel
+    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)
+
+    return Y
+
+
+@torch.library.register_fake("my::add_one")
+def _(X: torch.Tensor) -> torch.Tensor:
+    return X
+
+
+@trtp.register("my::add_one")
+def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
+    return X.like()
+
+
+@trtp.aot_impl("my::add_one")
+def add_plugin_aot_impl(
+    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+) -> Tuple[
+    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
+    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
+    block_size = 256
+    src = triton.compiler.ASTSource(
+        fn=add_one_kernel,
+        signature={
+            "x_ptr": f"*{type_str}",
+            "n_elements": "i32",
+            "y_ptr": f"*{type_str}",
+            "BLOCK_SIZE": "constexpr",
+        },
+        constants={
+            "BLOCK_SIZE": block_size,
+        },
+    )
+
+    compiled_kernel = triton.compile(src)
+
+    N = X.shape_expr.numel()
+    launch_params = trtp.KernelLaunchParams()
+
+    # grid dims
+    launch_params.grid_x = trtp.cdiv(N, block_size)
+    # block dims
+    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+    # shared memory
+    launch_params.shared_mem = compiled_kernel.metadata.shared
+
+    extra_args = trtp.SymIntExprs(1)
+    extra_args[0] = trtp.SymInt32(N)
+
+    return (
+        compiled_kernel.metadata.name,
+        compiled_kernel.asm["ptx"],
+        launch_params,
+        extra_args,
+    )
+
+
+torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
+    "my::add_one",
+    supports_dynamic_shapes=False,
+    requires_output_allocator=False,
+    use_aot_if_available=True,
+)
+
+
+class MyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        res = torch.ops.my.add_one.default(X)
+
+        return res
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--aot", action="store_true", help="Try to use AOT compilation", default=False
+    )
+    args = parser.parse_args()
+
+    my_model = MyModel().to("cuda")
+    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
+
+    assert my_model(X=m)[0][0] == 3.0
+
+    with torch_tensorrt.logging.debug():
+        trt_inputs = [m]
+        model_trt = torch_tensorrt.compile(
+            my_model,
+            inputs=trt_inputs,
+            debug=True,
+            min_block_size=1,
+        )
+        print("Model compiled successfully!")
+        print("Running inference with compiled model...")
+        for i in range(10):
+            res = model_trt(m)
+            assert torch.allclose(res, my_model(m)), "Results do not match!"
+
+    print("Inference successful!")
diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
index 99ea3bc356..3ab8a90588 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -31,6 +31,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    use_aot_if_available: bool = False,
 ) -> DynamoConverterImplSignature:
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
@@ -41,6 +42,16 @@ def _generate_plugin_converter(
     ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter"
     torch_schema = torch_target._schemas[overload_str]
 
+    use_aot_plugin = use_aot_if_available
+
+    if use_aot_if_available:
+        desc = QDP_REGISTRY[f"{namespace}::{op_name}"]
+        if desc.aot_impl_func is None:
+            use_aot_plugin = False
+            _LOGGER.debug(
+                f"AOT impl func not found for {namespace}::{op_name}, use JIT plugin instead"
+            )
+
     def custom_kernel_converter(
         ctx: ConversionContext,
         target: Target,
@@ -80,7 +91,7 @@ def custom_kernel_converter(
             if isinstance(v, torch.fx.immutable_collections.immutable_list):
                 kwargs[k] = np.array(v)
 
-        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
+        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=use_aot_plugin)
         assert layer, f"{namespace}::{name} plugin layer was not able to be created"
         _LOGGER.debug(
             f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -107,6 +118,7 @@ def generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    use_aot_if_available: bool = False,
 ) -> DynamoConverterImplSignature:
     plugin_ns, plugin_name = plugin_id.split("::")
     return _generate_plugin_converter(
@@ -116,4 +128,5 @@ def generate_plugin_converter(
         priority=priority,
         supports_dynamic_shapes=supports_dynamic_shapes,
         requires_output_allocator=requires_output_allocator,
+        use_aot_if_available=use_aot_if_available,
     )