Skip to content

Commit ad0d786

Browse files
authored
fix: convert_module_to_trt_engine (#2728)
1 parent e0b7ce7 commit ad0d786

File tree

4 files changed

+41
-37
lines changed

4 files changed

+41
-37
lines changed

docsrc/py_api/dynamo.rst

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ Functions
2222

2323
.. autofunction:: export
2424

25+
.. autofunction:: convert_module_to_trt_engine
26+
2527

2628

2729
Classes

py/torch_tensorrt/_compile.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import collections.abc
34
import logging
45
from enum import Enum
56
from typing import Any, Callable, List, Optional, Sequence, Set
@@ -240,8 +241,6 @@ def compile(
240241
return compiled_fx_module
241242
elif target_ir == _IRType.dynamo:
242243
# Prepare torch and torchtrt inputs
243-
import collections.abc
244-
245244
from torch_tensorrt.dynamo.utils import prepare_inputs
246245

247246
if not isinstance(input_list, collections.abc.Sequence):
@@ -345,10 +344,19 @@ def convert_method_to_trt_engine(
345344
"convert_method_to_trt_engine call is not supported for ir=fx"
346345
)
347346
elif target_ir == _IRType.dynamo:
347+
# Prepare torch and torchtrt inputs
348+
from torch_tensorrt.dynamo.utils import prepare_inputs
349+
350+
if not isinstance(inputs, collections.abc.Sequence):
351+
inputs = [inputs]
352+
353+
# Export the module
354+
torchtrt_inputs = prepare_inputs(inputs)
355+
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
356+
348357
return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
349-
module,
358+
exp_program,
350359
inputs=inputs,
351-
method_name=method_name,
352360
enabled_precisions=enabled_precisions_set,
353361
**kwargs,
354362
)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 21 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -416,8 +416,7 @@ def compile_module(
416416

417417

418418
def convert_module_to_trt_engine(
419-
module: torch.fx.GraphModule,
420-
method_name: str = "forward",
419+
exported_program: ExportedProgram,
421420
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
422421
enabled_precisions: (
423422
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -447,15 +446,15 @@ def convert_module_to_trt_engine(
447446
calibrator: object = None,
448447
allow_shape_tensors: bool = False,
449448
) -> bytes:
450-
"""Convert a GraphModule module method to a serialized TensorRT engine
449+
"""Convert an ExportedProgram to a serialized TensorRT engine
451450
452-
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
451+
Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings
453452
454453
Arguments:
455-
module (torch.fx.GraphModule): Source module
454+
exported_program (torch.export.ExportedProgram): Source module
456455
457456
Keyword Args:
458-
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
457+
inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
459458
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
460459
to select device type. ::
461460
@@ -470,30 +469,11 @@ def convert_module_to_trt_engine(
470469
), # Dynamic input shape for input #2
471470
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
472471
]
473-
474-
method_name (str): Name of method to convert
475-
input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
476-
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
477-
478-
input_signature=([
479-
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
480-
torch_tensorrt.Input(
481-
min_shape=(1, 224, 224, 3),
482-
opt_shape=(1, 512, 512, 3),
483-
max_shape=(1, 1024, 1024, 3),
484-
dtype=torch.int32
485-
format=torch.channel_last
486-
), # Dynamic input shape for input #2
487-
], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
488-
489-
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
490-
491-
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
492-
472+
enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
493473
debug (bool): Whether to print out verbose debugging information
494474
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
495475
min_block_size (int): Minimum number of operators per TRT-Engine Block
496-
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
476+
torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
497477
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
498478
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
499479
version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -560,13 +540,25 @@ def convert_module_to_trt_engine(
560540
"dla_global_dram_size": dla_global_dram_size,
561541
}
562542

543+
# Decompose the exported program
544+
exported_program = exported_program.run_decompositions(
545+
get_decompositions(enable_experimental_decompositions)
546+
)
547+
gm = exported_program.module()
548+
logger.debug("Input graph: " + str(gm.graph))
549+
550+
# Apply lowering on the graph module
551+
torch_inputs = get_torch_inputs(input_list, device)
552+
gm = apply_lowering_passes(gm, torch_inputs)
553+
logger.debug("Lowered Input graph: " + str(gm.graph))
554+
563555
settings = CompilationSettings(**compilation_options)
564556
logger.info("Compilation Settings: %s\n", settings)
565557
try:
566-
interpreter_result = interpret_module_to_result(module, input_list, settings)
558+
interpreter_result = interpret_module_to_result(gm, input_list, settings)
567559
except UnsupportedOperatorException:
568560
logger.error(
569-
f"Conversion of module {module} not currently fully supported or convertible!",
561+
f"Conversion of module {gm} not currently fully supported or convertible!",
570562
exc_info=True,
571563
)
572564
except Exception as e:

tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py renamed to tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
88

99

10-
class TestConvertMethodToTrtEngine(unittest.TestCase):
10+
class TestConvertModuleToTrtEngine(unittest.TestCase):
1111
def test_convert_module(self):
1212
class Test(torch.nn.Module):
1313
def forward(self, a, b):
@@ -18,19 +18,21 @@ def forward(self, a, b):
1818

1919
# Create a model
2020
model = Test()
21-
symbolic_traced_gm = torch.fx.symbolic_trace(model)
21+
exp_program = torch.export.export(model, (input_data_0, input_data_1))
2222

2323
# Convert to TensorRT engine
2424
trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
25-
symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
25+
exp_program, inputs=(input_data_0, input_data_1)
2626
)
2727

2828
# Deserialize the TensorRT engine
2929
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
3030
engine = runtime.deserialize_cuda_engine(trt_engine_str)
3131

3232
# Inference on TRT Engine
33-
py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
33+
py_trt_module = PythonTorchTensorRTModule(
34+
engine, ["arg0_1", "arg1_1"], ["output0"]
35+
)
3436
trt_output = py_trt_module(input_data_0, input_data_1).cpu()
3537

3638
# Inference on PyTorch model

0 commit comments

Comments
 (0)