Skip to content

Commit 1261983

Browse files
committed
Added engine visualization
1 parent f6a3f86 commit 1261983

13 files changed

+155
-13
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,18 @@ void TRTEngine::enable_profiling() {
281281
exec_ctx->setProfiler(trt_engine_profiler.get());
282282
}
283283

284+
void TRTEngine::set_profile_format(std::string format) {
285+
if (format == "trex") {
286+
profile_format = TraceFormat::kTREX;
287+
} else if (format == "perfetto") {
288+
profile_format = TraceFormat::kPERFETTO;
289+
} else {
290+
TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
291+
}
292+
293+
profile_format = profile_format;
294+
}
295+
284296
std::string TRTEngine::get_engine_layer_info() {
285297
auto inspector = cuda_engine->createEngineInspector();
286298
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);

core/runtime/TRTEngine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
147147
std::string to_str() const;
148148
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
149149
void enable_profiling();
150+
void set_profile_format(std::string profile_format);
150151
void disable_profiling();
151152
std::string get_engine_layer_info();
152153

@@ -191,6 +192,7 @@ struct TRTEngine : torch::CustomClassHolder {
191192
#else
192193
bool profile_execution = false;
193194
#endif
195+
TraceFormat profile_format = TraceFormat::kPERFETTO;
194196
std::string device_profile_path;
195197
std::string input_profile_path;
196198
std::string output_profile_path;

core/runtime/TRTEngineProfiler.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,36 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
3232
}
3333
}
3434

35-
void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
35+
void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format) {
3636
std::stringstream out;
3737
out << "[" << std::endl;
3838
double ts = 0.0;
39+
double running_time = 0.0;
40+
for (size_t i = 0; i < value.layer_names.size(); i++) {
41+
auto layer_name = value.layer_names[i];
42+
auto elem = value.profile.at(layer_name);
43+
ts += elem.time;
44+
}
3945
for (size_t i = 0; i < value.layer_names.size(); i++) {
4046
auto layer_name = value.layer_names[i];
4147
auto elem = value.profile.at(layer_name);
4248

4349
out << " {" << std::endl;
4450
out << " \"name\": \"" << layer_name << "\"," << std::endl;
45-
out << " \"ph\": \"X\"," << std::endl;
46-
out << " \"ts\": " << ts * 1000 << "," << std::endl;
47-
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
48-
out << " \"tid\": 1," << std::endl;
49-
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
50-
out << " \"args\": {}" << std::endl;
51+
if (format == kPERFETTO) {
52+
out << " \"ph\": \"X\"," << std::endl;
53+
out << " \"ts\": " << running_time * 1000 << "," << std::endl;
54+
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
55+
out << " \"tid\": 1," << std::endl;
56+
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
57+
} else { // kTREX
58+
out << " \"timeMs\": " << elem.time << "," << std::endl;
59+
out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl;
60+
out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl;
61+
out << " \"args\": {}" << std::endl;
62+
}
5163
out << " }," << std::endl;
52-
53-
ts += elem.time;
64+
running_time += elem.time;
5465
}
5566
out.seekp(-2, out.cur);
5667
out << "\n]" << std::endl;

core/runtime/TRTEngineProfiler.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ namespace torch_tensorrt {
1010
namespace core {
1111
namespace runtime {
1212

13+
enum TraceFormat { kPERFETTO, kTREX };
14+
15+
// Forward declare the function
16+
1317
struct TRTEngineProfiler : public nvinfer1::IProfiler {
1418
struct Record {
1519
float time{0};
@@ -21,7 +25,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
2125
const std::string& name,
2226
const std::vector<TRTEngineProfiler>& srcProfilers = std::vector<TRTEngineProfiler>());
2327
friend std::ostream& operator<<(std::ostream& out, const TRTEngineProfiler& value);
24-
friend void dump_trace(const std::string& path, const TRTEngineProfiler& value);
28+
friend void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format);
2529

2630
private:
2731
std::string name;

core/runtime/execute_engine.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
339339

340340
if (compiled_engine->profile_execution) {
341341
LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
342-
dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
342+
dump_trace(
343+
compiled_engine->trt_engine_profile_path,
344+
*compiled_engine->trt_engine_profiler,
345+
compiled_engine->profile_format);
343346
compiled_engine->dump_engine_layer_info();
344347
}
345348

@@ -440,7 +443,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
440443

441444
if (compiled_engine->profile_execution) {
442445
LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
443-
dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
446+
dump_trace(
447+
compiled_engine->trt_engine_profile_path,
448+
*compiled_engine->trt_engine_profiler,
449+
compiled_engine->profile_format);
444450
compiled_engine->dump_engine_layer_info();
445451
}
446452

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8282
.def("__repr__", &TRTEngine::to_str)
8383
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
8484
.def("enable_profiling", &TRTEngine::enable_profiling)
85+
.def("set_profile_format", &TRTEngine::set_profile_format)
8586
.def("disable_profiling", &TRTEngine::disable_profiling)
8687
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
8788
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import logging
5+
import os
56
import platform
67
import warnings
78
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -421,6 +422,7 @@ def compile(
421422
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422423
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
423424
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
425+
engine_vis_dir: Optional[str] = _defaults.ENGINE_VIS_DIR,
424426
**kwargs: Any,
425427
) -> torch.fx.GraphModule:
426428
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -674,6 +676,7 @@ def compile(
674676
"enable_weight_streaming": enable_weight_streaming,
675677
"tiling_optimization_level": tiling_optimization_level,
676678
"l2_limit_for_tiling": l2_limit_for_tiling,
679+
"engine_vis_dir": engine_vis_dir,
677680
}
678681

679682
settings = CompilationSettings(**compilation_options)
@@ -904,6 +907,19 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
904907

905908
trt_modules[name] = trt_module
906909

910+
if settings.debug and settings.engine_vis_dir:
911+
if settings.use_python_runtime:
912+
logger.warning(
913+
"Profiling can only be enabled when using the C++ runtime"
914+
)
915+
else:
916+
if not os.path.exists(settings.engine_vis_dir):
917+
os.makedirs(settings.engine_vis_dir)
918+
trt_module.enable_profiling(
919+
profiling_results_dir=settings.engine_vis_dir,
920+
profile_format="trex",
921+
)
922+
907923
# Parse the graph I/O and store it in dryrun tracker
908924
parse_graph_io(gm, dryrun_tracker)
909925

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
DLA_SRAM_SIZE = 1048576
1616
ENGINE_CAPABILITY = EngineCapability.STANDARD
1717
WORKSPACE_SIZE = 0
18+
ENGINE_VIS_DIR = None
1819
MIN_BLOCK_SIZE = 5
1920
PASS_THROUGH_BUILD_FAILURES = False
2021
MAX_AUX_STREAMS = None

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ENABLE_WEIGHT_STREAMING,
1919
ENABLED_PRECISIONS,
2020
ENGINE_CAPABILITY,
21+
ENGINE_VIS_DIR,
2122
HARDWARE_COMPATIBLE,
2223
IMMUTABLE_WEIGHTS,
2324
L2_LIMIT_FOR_TILING,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
engine_vis_dir: Optional[str] = ENGINE_VIS_DIR
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
334334

335335
return tuple(outputs)
336336

337-
def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
337+
def enable_profiling(
338+
self, profiling_results_dir: Optional[str] = None, profile_format: str = "trex"
339+
) -> None:
338340
"""Enable the profiler to collect latency information about the execution of the engine
339341
340342
Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives
@@ -347,7 +349,9 @@ def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
347349

348350
if profiling_results_dir is not None:
349351
self.engine.profile_path_prefix = profiling_results_dir
352+
assert profile_format in ["trex", "perfetto"]
350353
self.engine.enable_profiling()
354+
self.engine.set_profile_format(profile_format)
351355

352356
def disable_profiling(self) -> None:
353357
"""Disable the profiler"""
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
## Introduction
2+
We use the TRT Engine Explorer (TREX) to visualize the engien graph structure. TREX is a diagnostic and profiling tool for TensorRT engine files. It allows you to inspect, benchmark, and debug TensorRT engines with ease.
3+
4+
## Installation
5+
```bash
6+
git clone https://github.com/NVIDIA/TensorRT.git
7+
cd TensorRT/tools/experimental/trt-engine-explorer
8+
python3 -m pip install -e .[notebook]
9+
sudo apt --yes install graphviz
10+
```
11+
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import argparse
2+
import os
3+
import re
4+
import shutil
5+
import subprocess
6+
import warnings
7+
from typing import Tuple
8+
9+
import networkx as nx
10+
import trex
11+
import trex.engine_plan
12+
import trex.graphing
13+
14+
15+
def draw_engine(dir_path: str):
16+
try:
17+
import trex
18+
except ImportError:
19+
print("trex is required but it is not installed.\n")
20+
print("Check README.md for installation instructions.")
21+
exit()
22+
23+
engine_json_fname = os.path.join(
24+
dir_path, "_run_on_acc_0_engine_layer_information.json"
25+
)
26+
profiling_json_fname = os.path.join(
27+
dir_path, "_run_on_acc_0_engine_engine_exectuion_profile.trace"
28+
)
29+
30+
graphviz_is_installed = shutil.which("dot") is not None
31+
if not graphviz_is_installed:
32+
print("graphviz is required but it is not installed.\n")
33+
print("To install on Ubuntu:")
34+
print("sudo apt --yes install graphviz")
35+
exit()
36+
37+
plan = trex.engine_plan.EnginePlan(
38+
engine_json_fname, profiling_file=profiling_json_fname
39+
)
40+
layer_node_formatter = trex.graphing.layer_type_formatter
41+
graph = trex.graphing.to_dot(plan, layer_node_formatter)
42+
output_format = "png" # svg or jpg
43+
44+
trex.graphing.render_dot(graph, engine_json_fname, output_format)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
import torch
3+
import torch_tensorrt as torch_tensorrt
4+
import torchvision.models as models
5+
6+
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
7+
model = models.resnet18(pretrained=False).eval().to("cuda")
8+
exp_program = torch.export.export(model, tuple(inputs))
9+
enabled_precisions = {torch.float}
10+
debug = False
11+
workspace_size = 20 << 30
12+
min_block_size = 0
13+
use_python_runtime = False
14+
torch_executed_ops = {}
15+
trt_gm = torch_tensorrt.dynamo.compile(
16+
exp_program,
17+
inputs=inputs,
18+
enabled_precisions=enabled_precisions,
19+
truncate_double=True,
20+
debug=True,
21+
use_python_runtime=False,
22+
engine_vis_dir="/home/profile",
23+
)
24+
trt_output = trt_gm(*inputs)
25+
26+
from draw_engine_graph import draw_engine
27+
28+
draw_engine("/home/profile")

0 commit comments

Comments
 (0)