
Commit cc328e2

FX graph visualization (#3528)
Co-authored-by: Naren Dasan <[email protected]>
1 parent 1c00f0f

28 files changed: +560 additions, -80 deletions

core/runtime/TRTEngine.cpp

Lines changed: 11 additions & 1 deletion
@@ -281,6 +281,16 @@ void TRTEngine::enable_profiling() {
   exec_ctx->setProfiler(trt_engine_profiler.get());
 }
 
+void TRTEngine::set_profile_format(std::string format) {
+  if (format == "trex") {
+    this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
+  } else if (format == "perfetto") {
+    this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO);
+  } else {
+    TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
+  }
+}
+
 std::string TRTEngine::get_engine_layer_info() {
   auto inspector = cuda_engine->createEngineInspector();
   return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
@@ -315,7 +325,7 @@ void TRTEngine::set_profiling_paths() {
   output_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_output_profile.trace"}.string();
   enqueue_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_enqueue_profile.trace"}.string();
   trt_engine_profile_path =
-      std::filesystem::path{profile_path_prefix + "/" + name + "_engine_exectuion_profile.trace"}.string();
+      std::filesystem::path{profile_path_prefix + "/" + name + "_engine_execution_profile.trace"}.string();
   cuda_graph_debug_path = std::filesystem::path{profile_path_prefix + "/" + name + "_cudagraph.dot"}.string();
 }

core/runtime/TRTEngine.h

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string to_str() const;
   static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
   void enable_profiling();
+  void set_profile_format(std::string profile_format);
   void disable_profiling();
   std::string get_engine_layer_info();

core/runtime/TRTEngineProfiler.cpp

Lines changed: 23 additions & 8 deletions
@@ -32,25 +32,40 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
   }
 }
 
+void TRTEngineProfiler::set_profile_format(TraceFormat format) {
+  this->profile_format = format;
+}
+
 void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
   std::stringstream out;
   out << "[" << std::endl;
   double ts = 0.0;
+  double running_time = 0.0;
+  for (size_t i = 0; i < value.layer_names.size(); i++) {
+    auto layer_name = value.layer_names[i];
+    auto elem = value.profile.at(layer_name);
+    ts += elem.time;
+  }
   for (size_t i = 0; i < value.layer_names.size(); i++) {
     auto layer_name = value.layer_names[i];
     auto elem = value.profile.at(layer_name);
 
     out << "  {" << std::endl;
     out << "    \"name\": \"" << layer_name << "\"," << std::endl;
-    out << "    \"ph\": \"X\"," << std::endl;
-    out << "    \"ts\": " << ts * 1000 << "," << std::endl;
-    out << "    \"dur\": " << elem.time * 1000 << "," << std::endl;
-    out << "    \"tid\": 1," << std::endl;
-    out << "    \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
-    out << "    \"args\": {}" << std::endl;
+    if (value.profile_format == TraceFormat::kPERFETTO) {
+      out << "    \"ph\": \"X\"," << std::endl;
+      out << "    \"ts\": " << running_time * 1000 << "," << std::endl;
+      out << "    \"dur\": " << elem.time * 1000 << "," << std::endl;
+      out << "    \"tid\": 1," << std::endl;
+      out << "    \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
+      out << "    \"args\": {}" << std::endl;
+    } else { // kTREX
+      out << "    \"timeMs\": " << elem.time << "," << std::endl;
+      out << "    \"averageMs\": " << elem.time / elem.count << "," << std::endl;
+      out << "    \"percentage\": " << (elem.time * 100.0 / ts) << std::endl;
+    }
     out << "  }," << std::endl;
-
-    ts += elem.time;
+    running_time += elem.time;
   }
   out.seekp(-2, out.cur);
   out << "\n]" << std::endl;
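
For reference, the per-layer records emitted by dump_trace now differ by format. A sketch of the two shapes as Python dicts, with made-up timing values; the keys come from the diff above, everything else is illustrative:

    # Perfetto/Chrome trace event: "ph": "X" marks a complete event, and
    # "ts"/"dur" are microseconds, hence the * 1000 on millisecond timings.
    perfetto_record = {
        "name": "conv1",                      # layer name reported by TensorRT
        "ph": "X",
        "ts": 0.0,                            # running_time * 1000 at this layer
        "dur": 125.0,                         # elem.time * 1000
        "tid": 1,
        "pid": "my_engine Engine Execution",
        "args": {},
    }

    # TREX record: aggregate per-layer timing rather than a timeline event.
    trex_record = {
        "name": "conv1",
        "timeMs": 0.125,                      # elem.time, total across invocations
        "averageMs": 0.125 / 3,               # elem.time / elem.count (3 runs assumed)
        "percentage": 12.5,                   # elem.time * 100.0 / ts, share of total
    }

Note that "ts" in the Perfetto branch is now the layer's running start time, while the precomputed total ts is used only for the TREX percentage.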

core/runtime/TRTEngineProfiler.h

Lines changed: 4 additions & 1 deletion
@@ -10,12 +10,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+enum TraceFormat { kPERFETTO, kTREX };
+
 struct TRTEngineProfiler : public nvinfer1::IProfiler {
   struct Record {
     float time{0};
     int count{0};
   };
-
+  void set_profile_format(TraceFormat format);
   virtual void reportLayerTime(const char* layerName, float ms) noexcept;
   TRTEngineProfiler(
       const std::string& name,
@@ -27,6 +29,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
   std::string name;
   std::vector<std::string> layer_names;
   std::map<std::string, Record> profile;
+  TraceFormat profile_format = TraceFormat::kPERFETTO;
 };
 
 } // namespace runtime

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("__repr__", &TRTEngine::to_str)
         .def("__obj_flatten__", &TRTEngine::__obj_flatten__)
         .def("enable_profiling", &TRTEngine::enable_profiling)
+        .def("set_profile_format", &TRTEngine::set_profile_format)
         .def("disable_profiling", &TRTEngine::disable_profiling)
         .def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
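
With the binding registered, the format can be selected from Python on an engine handle. A minimal sketch, assuming trt_module is a compiled module whose engine attribute exposes the registered TRTEngine custom class (the attribute name is an assumption, not something this commit defines):

    engine = trt_module.engine           # assumed handle to the TRTEngine custom class
    engine.enable_profiling()            # existing binding
    engine.set_profile_format("trex")    # new binding; "perfetto" is also accepted,
                                         # any other string throws (see TRTEngine.cpp)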

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR
 from ._tracer import trace
+from .debug._Debugger import Debugger
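
This re-export makes the new context manager importable as torch_tensorrt.dynamo.Debugger, which the deprecation warnings below point users toward. A usage sketch; the keyword names mirror the DebuggerConfig fields referenced in _compiler.py (save_engine_profile, profile_format, save_layer_info, logging_dir), but the exact constructor signature is an assumption:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()                 # placeholder model
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # Previously: torch_tensorrt.dynamo.compile(..., debug=True)  (now deprecated)
    with torch_tensorrt.dynamo.Debugger(
        "debug",                                    # log level (assumed argument)
        logging_dir="/tmp/torch_tensorrt/debug_logs",
        save_engine_profile=True,
        profile_format="trex",                      # "perfetto"/"cudagraph" also appear in this commit
        save_layer_info=True,
    ):
        trt_gm = torch_tensorrt.dynamo.compile(
            torch.export.export(model, tuple(inputs)), inputs=inputs
        )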

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 61 additions & 19 deletions
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -32,6 +33,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -43,7 +46,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -66,7 +68,6 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -140,7 +141,6 @@ def cross_compile_for_windows(
         assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
-        debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
@@ -187,8 +187,12 @@ def cross_compile_for_windows(
             f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
         )
 
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -299,7 +303,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -401,7 +404,6 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -477,7 +479,6 @@ def compile(
         assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
-        debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
@@ -520,8 +521,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -643,7 +649,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -718,12 +723,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_config: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -747,7 +755,7 @@ def compile_module(
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
+        gm, settings.torch_executed_ops
     )
 
     dryrun_tracker.total_ops_in_graph = total_ops
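
compile_module is now decorated with fn_supports_debugger and takes a keyword-only _debugger_config, so an active Debugger context can hand its configuration to the compiler without widening the public API. A plausible minimal shape for such a decorator, purely illustrative and not the commit's actual implementation:

    import functools
    from typing import Any, Callable

    _ACTIVE_DEBUGGER_CONFIG = None  # would be set/cleared by the Debugger context (illustrative)

    def fn_supports_debugger(fn: Callable[..., Any]) -> Callable[..., Any]:
        # Inject the active debugger config, if any, into decorated functions.
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if _ACTIVE_DEBUGGER_CONFIG is not None:
                kwargs.setdefault("_debugger_config", _ACTIVE_DEBUGGER_CONFIG)
            return fn(*args, **kwargs)
        return wrapper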
@@ -799,7 +807,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
@@ -820,7 +827,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
@@ -928,6 +934,41 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+        if _debugger_config:
+
+            if _debugger_config.save_engine_profile:
+                if settings.use_python_runtime:
+                    if _debugger_config.profile_format != "cudagraph":
+                        raise ValueError(
+                            "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
+                        )
+                    else:
+                        trt_module.enable_profiling()
+                else:
+                    if _debugger_config.profile_format == "cudagraph":
+                        raise ValueError(
+                            "Profiling with Cudagraph can only be enabled when using the Python runtime. C++ runtime profiling only support TREX/Perfetto visualization."
+                        )
+                    else:
+                        path = os.path.join(
+                            _debugger_config.logging_dir,
+                            "engine_visualization_profile",
+                        )
+                        os.makedirs(path, exist_ok=True)
+                        trt_module.enable_profiling(
+                            profiling_results_dir=path,
+                            profile_format=_debugger_config.profile_format,
+                        )
+
+            if _debugger_config.save_layer_info:
+                with open(
+                    os.path.join(
+                        _debugger_config.logging_dir, "engine_layer_info.json"
+                    ),
+                    "w",
+                ) as f:
+                    f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
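
The branch above enforces a runtime/format pairing: the Python runtime supports only the cudagraph dot visualization, while the C++ runtime supports only the trex and perfetto trace formats. A standalone restatement of that rule (an illustrative helper, not part of the commit):

    def check_profile_format(use_python_runtime: bool, profile_format: str) -> None:
        # Mirrors the compatibility check in compile_module above.
        if use_python_runtime and profile_format != "cudagraph":
            raise ValueError("Python runtime profiling only supports cudagraph visualization.")
        if not use_python_runtime and profile_format == "cudagraph":
            raise ValueError("C++ runtime profiling only supports TREX/Perfetto visualization.")

    check_profile_format(use_python_runtime=False, profile_format="trex")      # ok
    check_profile_format(use_python_runtime=True, profile_format="cudagraph")  # ok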

@@ -955,7 +996,6 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1017,7 +1057,6 @@ def convert_exported_program_to_serialized_trt_engine(
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
         enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
-        debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
         torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
@@ -1057,8 +1096,12 @@ def convert_exported_program_to_serialized_trt_engine(
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -1142,7 +1185,6 @@ def convert_exported_program_to_serialized_trt_engine(
     compilation_options = {
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "enabled_precisions": enabled_precisions,
-        "debug": debug,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 4 additions & 1 deletion
@@ -1,12 +1,12 @@
 import os
+import pwd
 import tempfile
 
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 
 ENABLED_PRECISIONS = {dtype.f32}
-DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
 ASSUME_DYNAMIC_SHAPE_SUPPORT = False
@@ -57,6 +57,9 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+DEBUG_LOGGING_DIR = os.path.join(
+    tempfile.gettempdir(), pwd.getpwuid(os.getuid())[0], "torch_tensorrt/debug_logs"
+)
 
 
 def default_device() -> Device: