2
2
3
3
import collections .abc
4
4
import logging
5
+ import os
5
6
import platform
6
7
import warnings
7
8
from typing import Any , Collection , List , Optional , Sequence , Set , Tuple , Union
32
33
from torch_tensorrt .dynamo .conversion ._ConverterRegistry import (
33
34
DYNAMO_CONVERTERS as CONVERTERS ,
34
35
)
36
+ from torch_tensorrt .dynamo .debug ._DebuggerConfig import DebuggerConfig
37
+ from torch_tensorrt .dynamo .debug ._supports_debugger import fn_supports_debugger
35
38
from torch_tensorrt .dynamo .lowering import (
36
39
get_decompositions ,
37
40
post_lowering ,
43
46
get_output_metadata ,
44
47
parse_graph_io ,
45
48
prepare_inputs ,
46
- set_log_level ,
47
49
to_torch_device ,
48
50
to_torch_tensorrt_device ,
49
51
)
@@ -66,7 +68,6 @@ def cross_compile_for_windows(
66
68
Set [Union [torch .dtype , dtype ]], Tuple [Union [torch .dtype , dtype ]]
67
69
] = _defaults .ENABLED_PRECISIONS ,
68
70
engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
69
- debug : bool = _defaults .DEBUG ,
70
71
num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
71
72
workspace_size : int = _defaults .WORKSPACE_SIZE ,
72
73
dla_sram_size : int = _defaults .DLA_SRAM_SIZE ,
@@ -140,7 +141,6 @@ def cross_compile_for_windows(
140
141
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
141
142
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
142
143
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
143
- debug (bool): Enable debuggable engine
144
144
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
145
145
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
146
146
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -187,8 +187,12 @@ def cross_compile_for_windows(
187
187
f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: { platform .system ()= } , { platform .architecture ()[0 ]= } "
188
188
)
189
189
190
- if debug :
191
- set_log_level (logger .parent , logging .DEBUG )
190
+ if kwargs .get ("debug" , False ):
191
+ warnings .warn (
192
+ "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality." ,
193
+ DeprecationWarning ,
194
+ stacklevel = 2 ,
195
+ )
192
196
193
197
if "truncate_long_and_double" in kwargs .keys ():
194
198
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
@@ -299,7 +303,6 @@ def cross_compile_for_windows(
299
303
"enabled_precisions" : (
300
304
enabled_precisions if enabled_precisions else _defaults .ENABLED_PRECISIONS
301
305
),
302
- "debug" : debug ,
303
306
"device" : device ,
304
307
"assume_dynamic_shape_support" : assume_dynamic_shape_support ,
305
308
"workspace_size" : workspace_size ,
@@ -401,7 +404,6 @@ def compile(
401
404
Set [Union [torch .dtype , dtype ]], Tuple [Union [torch .dtype , dtype ]]
402
405
] = _defaults .ENABLED_PRECISIONS ,
403
406
engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
404
- debug : bool = _defaults .DEBUG ,
405
407
num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
406
408
workspace_size : int = _defaults .WORKSPACE_SIZE ,
407
409
dla_sram_size : int = _defaults .DLA_SRAM_SIZE ,
@@ -477,7 +479,6 @@ def compile(
477
479
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
478
480
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
479
481
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
480
- debug (bool): Enable debuggable engine
481
482
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
482
483
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
483
484
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -520,8 +521,13 @@ def compile(
520
521
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
521
522
"""
522
523
523
- if debug :
524
- set_log_level (logger .parent , logging .DEBUG )
524
+ if kwargs .get ("debug" , False ):
525
+ warnings .warn (
526
+ "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality" ,
527
+ DeprecationWarning ,
528
+ stacklevel = 2 ,
529
+ )
530
+
525
531
if "truncate_long_and_double" in kwargs .keys ():
526
532
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
527
533
raise ValueError (
@@ -643,7 +649,6 @@ def compile(
643
649
"enabled_precisions" : (
644
650
enabled_precisions if enabled_precisions else _defaults .ENABLED_PRECISIONS
645
651
),
646
- "debug" : debug ,
647
652
"device" : device ,
648
653
"assume_dynamic_shape_support" : assume_dynamic_shape_support ,
649
654
"workspace_size" : workspace_size ,
@@ -718,12 +723,15 @@ def compile(
718
723
return trt_gm
719
724
720
725
726
+ @fn_supports_debugger
721
727
def compile_module (
722
728
gm : torch .fx .GraphModule ,
723
729
sample_arg_inputs : Sequence [Input ],
724
730
sample_kwarg_inputs : Optional [dict [Any , Any ]] = None ,
725
731
settings : CompilationSettings = CompilationSettings (),
726
732
engine_cache : Optional [BaseEngineCache ] = None ,
733
+ * ,
734
+ _debugger_config : Optional [DebuggerConfig ] = None ,
727
735
) -> torch .fx .GraphModule :
728
736
"""Compile a traced FX module
729
737
@@ -747,7 +755,7 @@ def compile_module(
747
755
748
756
# Check the number of supported operations in the graph
749
757
num_supported_ops , total_ops = partitioning .get_graph_converter_support (
750
- gm , settings .debug , settings . torch_executed_ops
758
+ gm , settings .torch_executed_ops
751
759
)
752
760
753
761
dryrun_tracker .total_ops_in_graph = total_ops
@@ -799,7 +807,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
799
807
logger .info ("Partitioning the graph via the fast partitioner" )
800
808
partitioned_module , supported_ops = partitioning .fast_partition (
801
809
gm ,
802
- verbose = settings .debug ,
803
810
min_block_size = settings .min_block_size ,
804
811
torch_executed_ops = settings .torch_executed_ops ,
805
812
require_full_compilation = settings .require_full_compilation ,
@@ -820,7 +827,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
820
827
logger .info ("Partitioning the graph via the global partitioner" )
821
828
partitioned_module , supported_ops = partitioning .global_partition (
822
829
gm ,
823
- verbose = settings .debug ,
824
830
min_block_size = settings .min_block_size ,
825
831
torch_executed_ops = settings .torch_executed_ops ,
826
832
require_full_compilation = settings .require_full_compilation ,
@@ -928,6 +934,41 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
928
934
929
935
trt_modules [name ] = trt_module
930
936
937
+ if _debugger_config :
938
+
939
+ if _debugger_config .save_engine_profile :
940
+ if settings .use_python_runtime :
941
+ if _debugger_config .profile_format != "cudagraph" :
942
+ raise ValueError (
943
+ "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
944
+ )
945
+ else :
946
+ trt_module .enable_profiling ()
947
+ else :
948
+ if _debugger_config .profile_format == "cudagraph" :
949
+ raise ValueError (
950
+ "Profiling with Cudagraph can only be enabled when using the Python runtime. C++ runtime profiling only support TREX/Perfetto visualization."
951
+ )
952
+ else :
953
+ path = os .path .join (
954
+ _debugger_config .logging_dir ,
955
+ "engine_visualization_profile" ,
956
+ )
957
+ os .makedirs (path , exist_ok = True )
958
+ trt_module .enable_profiling (
959
+ profiling_results_dir = path ,
960
+ profile_format = _debugger_config .profile_format ,
961
+ )
962
+
963
+ if _debugger_config .save_layer_info :
964
+ with open (
965
+ os .path .join (
966
+ _debugger_config .logging_dir , "engine_layer_info.json"
967
+ ),
968
+ "w" ,
969
+ ) as f :
970
+ f .write (trt_module .get_layer_info ())
971
+
931
972
# Parse the graph I/O and store it in dryrun tracker
932
973
parse_graph_io (gm , dryrun_tracker )
933
974
@@ -955,7 +996,6 @@ def convert_exported_program_to_serialized_trt_engine(
955
996
enabled_precisions : (
956
997
Set [torch .dtype | dtype ] | Tuple [torch .dtype | dtype ]
957
998
) = _defaults .ENABLED_PRECISIONS ,
958
- debug : bool = _defaults .DEBUG ,
959
999
assume_dynamic_shape_support : bool = _defaults .ASSUME_DYNAMIC_SHAPE_SUPPORT ,
960
1000
workspace_size : int = _defaults .WORKSPACE_SIZE ,
961
1001
min_block_size : int = _defaults .MIN_BLOCK_SIZE ,
@@ -1017,7 +1057,6 @@ def convert_exported_program_to_serialized_trt_engine(
1017
1057
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
1018
1058
]
1019
1059
enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
1020
- debug (bool): Whether to print out verbose debugging information
1021
1060
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
1022
1061
min_block_size (int): Minimum number of operators per TRT-Engine Block
1023
1062
torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
@@ -1057,8 +1096,12 @@ def convert_exported_program_to_serialized_trt_engine(
1057
1096
Returns:
1058
1097
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
1059
1098
"""
1060
- if debug :
1061
- set_log_level (logger .parent , logging .DEBUG )
1099
+ if kwargs .get ("debug" , False ):
1100
+ warnings .warn (
1101
+ "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality." ,
1102
+ DeprecationWarning ,
1103
+ stacklevel = 2 ,
1104
+ )
1062
1105
1063
1106
if "truncate_long_and_double" in kwargs .keys ():
1064
1107
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
@@ -1142,7 +1185,6 @@ def convert_exported_program_to_serialized_trt_engine(
1142
1185
compilation_options = {
1143
1186
"assume_dynamic_shape_support" : assume_dynamic_shape_support ,
1144
1187
"enabled_precisions" : enabled_precisions ,
1145
- "debug" : debug ,
1146
1188
"workspace_size" : workspace_size ,
1147
1189
"min_block_size" : min_block_size ,
1148
1190
"torch_executed_ops" : torch_executed_ops ,
0 commit comments