From 7da057ea0f12b5b53c999bf7ca6890d7f43af1a2 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 21 May 2025 17:41:48 +0000 Subject: [PATCH 01/12] Added initial attenpt to implement fx graph visualization --- .../dynamo/lowering/passes/draw_fx_graph.py | 53 +++++++++++++++++++ .../dynamo/lowering/passes/pass_manager.py | 19 +++++++ pyproject.toml | 6 +++ 3 files changed, 78 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py diff --git a/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py b/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py new file mode 100644 index 0000000000..c0d62e4923 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py @@ -0,0 +1,53 @@ +import torch +from torch.fx import passes +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes import LoweringPassSignature + +PRE_DEBUG_NAME = { + 0: "exported_program", + 1: "after_remove_detach,", +} + +POST_DEBUG_NAME = { + 0: "after_decomposition", + 1: "after_remove_input_alias_fixing_clones", + 2: "after_constant_fold", + 3: "after_repair_input_as_output", + 4: "after_fuse_prims_broadcast", + 5: "after_replace_max_pool_with_indices", + 6: "after_remove_assert_nodes", + 7: "after_accumulate_fp32_matmul", + 8: "after_remove_num_users_is_0_nodes", +} + + +def get_draw_fx_graph_pass_post_lowering( + idx: int, path_prefix: str +) -> LoweringPassSignature: + + def draw_fx_graph_pass( + gm: torch.fx.GraphModule, settings: CompilationSettings + ) -> torch.fx.GraphModule: + path = f"{path_prefix}_{POST_DEBUG_NAME[idx]}.svg" + g = passes.graph_drawer.FxGraphDrawer(gm, POST_DEBUG_NAME[idx]) + with open(path, "wb") as f: + f.write(g.get_dot_graph().create_svg()) + return gm + + return draw_fx_graph_pass + + +def get_draw_fx_graph_pass_pre_lowering( + idx: int, path_prefix: str +) -> LoweringPassSignature: + + def draw_fx_graph_pass( + gm: torch.fx.GraphModule, settings: 
CompilationSettings + ) -> torch.fx.GraphModule: + path = f"{path_prefix}_{PRE_DEBUG_NAME[idx]}.svg" + g = passes.graph_drawer.FxGraphDrawer(gm, PRE_DEBUG_NAME[idx]) + with open(path, "wb") as f: + f.write(g.get_dot_graph().create_svg()) + return gm + + return draw_fx_graph_pass diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index c793b1e1c9..7e07b10971 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -3,6 +3,10 @@ import torch from torch.fx.passes.pass_manager import PassManager from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.draw_fx_graph import ( + get_draw_fx_graph_pass_post_lowering, + get_draw_fx_graph_pass_pre_lowering, +) class DynamoPassManager(PassManager): # type: ignore[misc] @@ -49,6 +53,21 @@ def add_pass_with_index( def remove_pass_with_index(self, index: int) -> None: del self.passes[index] + def insert_debug_pass( + self, index: List[int], filename_prefix: str, post: bool = True + ) -> None: + + for i in range(len(index)): + if post: + debug_pass = get_draw_fx_graph_pass_post_lowering( + index[i], filename_prefix + ) + else: + debug_pass = get_draw_fx_graph_pass_pre_lowering( + index[i], filename_prefix + ) + self.add_pass_with_index(debug_pass, index[i] + i) + def __call__(self, gm: Any, settings: CompilationSettings) -> Any: self.validate() out = gm diff --git a/pyproject.toml b/pyproject.toml index 3bb857e3e0..4e2649cc9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,12 @@ dev = [ "pyyaml", ] +debug = [ + "pydot >= 4.0.0", + "tabulate >= 0.8.10", + "graphviz >= 0.20.3" +] + [project.optional-dependencies] torchvision = [ "torchvision", From b4e5e8f05d3b34a8ea060e688daeddd57b1b8fe3 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 22 May 2025 04:55:10 +0000 Subject: [PATCH 02/12] Added auto generated 
names --- .../lowering/passes/_aten_lowering_pass.py | 18 +++++++++------- .../dynamo/lowering/passes/draw_fx_graph.py | 21 ++++++++----------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..c7fe264c5a 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -17,7 +17,7 @@ from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices -pass_list = [ +post_lowering_pass_list = [ remove_input_alias_fixing_clones, constant_fold, repair_input_as_output, @@ -28,17 +28,19 @@ remove_num_users_is_0_nodes, ] -if not is_tegra_platform(): - pass_list.append(fuse_distributed_ops) +pre_lowering_pass_list = [ + remove_detach, +] -ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list) +if not is_tegra_platform(): + post_lowering_pass_list.append(fuse_distributed_ops) -ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( - [ - remove_detach, - ] +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + post_lowering_pass_list ) +ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pre_lowering_pass_list) + logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py b/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py index c0d62e4923..6ff64ebc82 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py @@ -1,24 +1,21 @@ import torch from torch.fx import passes from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes import LoweringPassSignature +from torch_tensorrt.dynamo.lowering.passes import ( + LoweringPassSignature, + 
post_lowering_pass_list, + pre_lowering_pass_list, +) PRE_DEBUG_NAME = { - 0: "exported_program", - 1: "after_remove_detach,", + i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list) } +PRE_DEBUG_NAME[0] = "exported_program" POST_DEBUG_NAME = { - 0: "after_decomposition", - 1: "after_remove_input_alias_fixing_clones", - 2: "after_constant_fold", - 3: "after_repair_input_as_output", - 4: "after_fuse_prims_broadcast", - 5: "after_replace_max_pool_with_indices", - 6: "after_remove_assert_nodes", - 7: "after_accumulate_fp32_matmul", - 8: "after_remove_num_users_is_0_nodes", + i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list) } +POST_DEBUG_NAME[0] = "after_decomposition" def get_draw_fx_graph_pass_post_lowering( From e77197ec6da3871d30d3a60660f87cb1a4bd9202 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 22 May 2025 17:01:19 +0000 Subject: [PATCH 03/12] merged the file to pass_manager --- .../dynamo/lowering/passes/draw_fx_graph.py | 50 ----------------- .../dynamo/lowering/passes/pass_manager.py | 53 +++++++++++++------ 2 files changed, 38 insertions(+), 65 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py diff --git a/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py b/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py deleted file mode 100644 index 6ff64ebc82..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -from torch.fx import passes -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes import ( - LoweringPassSignature, - post_lowering_pass_list, - pre_lowering_pass_list, -) - -PRE_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list) -} -PRE_DEBUG_NAME[0] = "exported_program" - -POST_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list) -} -POST_DEBUG_NAME[0] 
= "after_decomposition" - - -def get_draw_fx_graph_pass_post_lowering( - idx: int, path_prefix: str -) -> LoweringPassSignature: - - def draw_fx_graph_pass( - gm: torch.fx.GraphModule, settings: CompilationSettings - ) -> torch.fx.GraphModule: - path = f"{path_prefix}_{POST_DEBUG_NAME[idx]}.svg" - g = passes.graph_drawer.FxGraphDrawer(gm, POST_DEBUG_NAME[idx]) - with open(path, "wb") as f: - f.write(g.get_dot_graph().create_svg()) - return gm - - return draw_fx_graph_pass - - -def get_draw_fx_graph_pass_pre_lowering( - idx: int, path_prefix: str -) -> LoweringPassSignature: - - def draw_fx_graph_pass( - gm: torch.fx.GraphModule, settings: CompilationSettings - ) -> torch.fx.GraphModule: - path = f"{path_prefix}_{PRE_DEBUG_NAME[idx]}.svg" - g = passes.graph_drawer.FxGraphDrawer(gm, PRE_DEBUG_NAME[idx]) - with open(path, "wb") as f: - f.write(g.get_dot_graph().create_svg()) - return gm - - return draw_fx_graph_pass diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index 7e07b10971..73be3b2400 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,12 +1,40 @@ -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional import torch +from torch.fx import passes from torch.fx.passes.pass_manager import PassManager from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes.draw_fx_graph import ( - get_draw_fx_graph_pass_post_lowering, - get_draw_fx_graph_pass_pre_lowering, -) + + +def get_draw_fx_graph_pass_lowering( + idx: int, path_prefix: str, post: bool +) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]: + from torch_tensorrt.dynamo.lowering.passes import ( + post_lowering_pass_list, + pre_lowering_pass_list, + ) + + PRE_DEBUG_NAME = { + i + 1: f"after_{p.__name__}" for i, p in 
enumerate(pre_lowering_pass_list) + } + PRE_DEBUG_NAME[0] = "exported_program" + + POST_DEBUG_NAME = { + i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list) + } + POST_DEBUG_NAME[0] = "after_decomposition" + + def draw_fx_graph_pass( + gm: torch.fx.GraphModule, settings: CompilationSettings + ) -> torch.fx.GraphModule: + DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx] + path = f"{path_prefix}_{DEBUG_NAME}.svg" + g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME) + with open(path, "wb") as f: + f.write(g.get_dot_graph().create_svg()) + return gm + + return draw_fx_graph_pass class DynamoPassManager(PassManager): # type: ignore[misc] @@ -39,8 +67,7 @@ def build_from_passlist( def add_pass_with_index( self, lowering_pass: Callable[ - [torch.fx.GraphModule, CompilationSettings, Sequence[torch.Tensor]], - torch.fx.GraphModule, + [torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule ], index: Optional[int] = None, ) -> None: @@ -58,14 +85,10 @@ def insert_debug_pass( ) -> None: for i in range(len(index)): - if post: - debug_pass = get_draw_fx_graph_pass_post_lowering( - index[i], filename_prefix - ) - else: - debug_pass = get_draw_fx_graph_pass_pre_lowering( - index[i], filename_prefix - ) + + debug_pass = get_draw_fx_graph_pass_lowering( + index[i], filename_prefix, post + ) self.add_pass_with_index(debug_pass, index[i] + i) def __call__(self, gm: Any, settings: CompilationSettings) -> Any: From 42a0365f98a83c35ecb670a6b8b995f6267e471c Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Thu, 22 May 2025 22:31:50 -0600 Subject: [PATCH 04/12] Simplify pass manager debug system (#3530) --- .../dynamo/lowering/passes/pass_manager.py | 77 ++++++++++++------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index 73be3b2400..7dbaf70571 100644 --- 
a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, List, Optional +import tempfile +from types import new_class +from typing import Any, Callable, List, Optional, Union import torch from torch.fx import passes @@ -6,30 +8,14 @@ from torch_tensorrt.dynamo._settings import CompilationSettings -def get_draw_fx_graph_pass_lowering( - idx: int, path_prefix: str, post: bool +def _generate_draw_fx_graph_pass( + output_path_prefix: str, name: str ) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]: - from torch_tensorrt.dynamo.lowering.passes import ( - post_lowering_pass_list, - pre_lowering_pass_list, - ) - - PRE_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list) - } - PRE_DEBUG_NAME[0] = "exported_program" - - POST_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list) - } - POST_DEBUG_NAME[0] = "after_decomposition" - def draw_fx_graph_pass( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: - DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx] - path = f"{path_prefix}_{DEBUG_NAME}.svg" - g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME) + path = f"{output_path_prefix}/{name}.svg" + g = passes.graph_drawer.FxGraphDrawer(gm, name) with open(path, "wb") as f: f.write(g.get_dot_graph().create_svg()) return gm @@ -47,8 +33,9 @@ def __init__( ] ] ] = None, + constraints: Optional[List[Callable]] = None ): - super().__init__(passes) + super().__init__(passes, constraints) @classmethod def build_from_passlist( @@ -80,16 +67,48 @@ def add_pass_with_index( def remove_pass_with_index(self, index: int) -> None: del self.passes[index] - def insert_debug_pass( - self, index: List[int], filename_prefix: str, post: bool = True + def insert_debug_pass_before( + self, passes: List[str], output_path_prefix: 
str=tempfile.gettempdir() ) -> None: + """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass. + + Args: + passes: List of pass names to insert debug passes before + output_path_prefix: Prefix to use for generated debug files + + Debug passes generate SVG visualizations of the FX graph at specified points + in the pass sequence. + """ + new_pass_list = [] + for ps in self.passes: + if ps.__name__ in passes: + new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}")) + new_pass_list.append(ps) + + self.passes = new_pass_list + self._validated = False + + def insert_debug_pass_after( + self, passes: List[str], output_path_prefix: str=tempfile.gettempdir() + ) -> None: + """Insert debug passes in the PassManager pass sequence after the execution of a particular pass. + + Args: + passes: List of pass names to insert debug passes after + output_path_prefix: Prefix to use for generated debug files + + Debug passes generate SVG visualizations of the FX graph at specified points + in the pass sequence. 
+ """ + new_pass_list = [] + for ps in self.passes: + new_pass_list.append(ps) + if ps.__name__ in passes: + new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}")) - for i in range(len(index)): - debug_pass = get_draw_fx_graph_pass_lowering( - index[i], filename_prefix, post - ) - self.add_pass_with_index(debug_pass, index[i] + i) + self.passes = new_pass_list + self._validated = False def __call__(self, gm: Any, settings: CompilationSettings) -> Any: self.validate() From 7dc1618e908fb1bd4333dcca802abf5097fd5e4d Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 27 May 2025 18:24:28 +0000 Subject: [PATCH 05/12] Added pass name check --- .../dynamo/lowering/passes/pass_manager.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index 7dbaf70571..c55897ff45 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,6 +1,6 @@ +import os import tempfile -from types import new_class -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional import torch from torch.fx import passes @@ -14,6 +14,8 @@ def _generate_draw_fx_graph_pass( def draw_fx_graph_pass( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: + if not os.path.exists(f"{output_path_prefix}/"): + os.makedirs(f"{output_path_prefix}/") path = f"{output_path_prefix}/{name}.svg" g = passes.graph_drawer.FxGraphDrawer(gm, name) with open(path, "wb") as f: @@ -33,7 +35,7 @@ def __init__( ] ] ] = None, - constraints: Optional[List[Callable]] = None + constraints: Optional[List[Callable]] = None, ): super().__init__(passes, constraints) @@ -68,7 +70,7 @@ def remove_pass_with_index(self, index: int) -> None: del self.passes[index] def insert_debug_pass_before( - self, 
passes: List[str], output_path_prefix: str=tempfile.gettempdir() + self, passes: List[str], output_path_prefix: str = tempfile.gettempdir() ) -> None: """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass. @@ -79,17 +81,22 @@ def insert_debug_pass_before( Debug passes generate SVG visualizations of the FX graph at specified points in the pass sequence. """ + self.check_pass_names_valid(passes) new_pass_list = [] for ps in self.passes: if ps.__name__ in passes: - new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}")) + new_pass_list.append( + _generate_draw_fx_graph_pass( + output_path_prefix, f"before_{ps.__name__}" + ) + ) new_pass_list.append(ps) self.passes = new_pass_list self._validated = False def insert_debug_pass_after( - self, passes: List[str], output_path_prefix: str=tempfile.gettempdir() + self, passes: List[str], output_path_prefix: str = tempfile.gettempdir() ) -> None: """Insert debug passes in the PassManager pass sequence after the execution of a particular pass. @@ -100,16 +107,27 @@ def insert_debug_pass_after( Debug passes generate SVG visualizations of the FX graph at specified points in the pass sequence. """ + self.check_pass_names_valid(passes) new_pass_list = [] for ps in self.passes: new_pass_list.append(ps) if ps.__name__ in passes: - new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}")) - + new_pass_list.append( + _generate_draw_fx_graph_pass( + output_path_prefix, f"after_{ps.__name__}" + ) + ) self.passes = new_pass_list self._validated = False + def check_pass_names_valid(self, debug_pass_names: List[str]) -> None: + pass_names_str = [p.__name__ for p in self.passes] + for name in debug_pass_names: + assert ( + name in pass_names_str + ), f"{name} is not a valid pass! 
Passes: {pass_names_str}" + def __call__(self, gm: Any, settings: CompilationSettings) -> Any: self.validate() out = gm From be23d490569a5cdcf57b8f051a5b86984ee17721 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 28 May 2025 01:29:37 +0000 Subject: [PATCH 06/12] Added engine visualization --- core/runtime/TRTEngine.cpp | 12 +++++ core/runtime/TRTEngine.h | 2 + core/runtime/TRTEngineProfiler.cpp | 29 ++++++++---- core/runtime/TRTEngineProfiler.h | 6 ++- core/runtime/execute_engine.cpp | 10 ++++- core/runtime/register_jit_hooks.cpp | 1 + py/torch_tensorrt/dynamo/_compiler.py | 14 ++++++ py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 1 + .../dynamo/runtime/_TorchTensorRTModule.py | 6 ++- tools/debug/engine_visualization/README.md | 11 +++++ .../engine_visualization/draw_engine_graph.py | 44 +++++++++++++++++++ tools/debug/engine_visualization/llama_hlo.py | 28 ++++++++++++ 13 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 tools/debug/engine_visualization/README.md create mode 100644 tools/debug/engine_visualization/draw_engine_graph.py create mode 100644 tools/debug/engine_visualization/llama_hlo.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 9a04aba6de..6fd067a20f 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -281,6 +281,18 @@ void TRTEngine::enable_profiling() { exec_ctx->setProfiler(trt_engine_profiler.get()); } +void TRTEngine::set_profile_format(std::string format) { + if (format == "trex") { + profile_format = TraceFormat::kTREX; + } else if (format == "perfetto") { + profile_format = TraceFormat::kPERFETTO; + } else { + TORCHTRT_THROW_ERROR("Invalid profile format: " + format); + } + + profile_format = profile_format; +} + std::string TRTEngine::get_engine_layer_info() { auto inspector = cuda_engine->createEngineInspector(); return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON); diff --git 
a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 2db640b6b1..23bada84cd 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder { std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); void enable_profiling(); + void set_profile_format(std::string profile_format); void disable_profiling(); std::string get_engine_layer_info(); @@ -191,6 +192,7 @@ struct TRTEngine : torch::CustomClassHolder { #else bool profile_execution = false; #endif + TraceFormat profile_format = TraceFormat::kPERFETTO; std::string device_profile_path; std::string input_profile_path; std::string output_profile_path; diff --git a/core/runtime/TRTEngineProfiler.cpp b/core/runtime/TRTEngineProfiler.cpp index 8f7f0ac4e9..5996a75e85 100644 --- a/core/runtime/TRTEngineProfiler.cpp +++ b/core/runtime/TRTEngineProfiler.cpp @@ -32,25 +32,36 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector< } } -void dump_trace(const std::string& path, const TRTEngineProfiler& value) { +void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format) { std::stringstream out; out << "[" << std::endl; double ts = 0.0; + double running_time = 0.0; + for (size_t i = 0; i < value.layer_names.size(); i++) { + auto layer_name = value.layer_names[i]; + auto elem = value.profile.at(layer_name); + ts += elem.time; + } for (size_t i = 0; i < value.layer_names.size(); i++) { auto layer_name = value.layer_names[i]; auto elem = value.profile.at(layer_name); out << " {" << std::endl; out << " \"name\": \"" << layer_name << "\"," << std::endl; - out << " \"ph\": \"X\"," << std::endl; - out << " \"ts\": " << ts * 1000 << "," << std::endl; - out << " \"dur\": " << elem.time * 1000 << "," << std::endl; - out << " \"tid\": 1," << std::endl; - out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl; - out << " \"args\": {}" << 
std::endl; + if (format == kPERFETTO) { + out << " \"ph\": \"X\"," << std::endl; + out << " \"ts\": " << running_time * 1000 << "," << std::endl; + out << " \"dur\": " << elem.time * 1000 << "," << std::endl; + out << " \"tid\": 1," << std::endl; + out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl; + } else { // kTREX + out << " \"timeMs\": " << elem.time << "," << std::endl; + out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl; + out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl; + out << " \"args\": {}" << std::endl; + } out << " }," << std::endl; - - ts += elem.time; + running_time += elem.time; } out.seekp(-2, out.cur); out << "\n]" << std::endl; diff --git a/core/runtime/TRTEngineProfiler.h b/core/runtime/TRTEngineProfiler.h index 34a901165b..682fa3889d 100644 --- a/core/runtime/TRTEngineProfiler.h +++ b/core/runtime/TRTEngineProfiler.h @@ -10,6 +10,10 @@ namespace torch_tensorrt { namespace core { namespace runtime { +enum TraceFormat { kPERFETTO, kTREX }; + +// Forward declare the function + struct TRTEngineProfiler : public nvinfer1::IProfiler { struct Record { float time{0}; @@ -21,7 +25,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler { const std::string& name, const std::vector& srcProfilers = std::vector()); friend std::ostream& operator<<(std::ostream& out, const TRTEngineProfiler& value); - friend void dump_trace(const std::string& path, const TRTEngineProfiler& value); + friend void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format); private: std::string name; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..e24ea8df29 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -339,7 +339,10 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->profile_execution) { LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - 
dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); + dump_trace( + compiled_engine->trt_engine_profile_path, + *compiled_engine->trt_engine_profiler, + compiled_engine->profile_format); compiled_engine->dump_engine_layer_info(); } @@ -440,7 +443,10 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->profile_execution) { LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); + dump_trace( + compiled_engine->trt_engine_profile_path, + *compiled_engine->trt_engine_profiler, + compiled_engine->profile_format); compiled_engine->dump_engine_layer_info(); } diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index cbe19b0af6..173ff8c35f 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("__repr__", &TRTEngine::to_str) .def("__obj_flatten__", &TRTEngine::__obj_flatten__) .def("enable_profiling", &TRTEngine::enable_profiling) + .def("set_profile_format", &TRTEngine::set_profile_format) .def("disable_profiling", &TRTEngine::disable_profiling) .def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix) .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 831ce37305..c8541734d9 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,6 +2,7 @@ import collections.abc import logging +import os import platform import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union @@ -925,6 +926,19 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules[name] = trt_module + if settings.debug and settings.engine_vis_dir: + if 
settings.use_python_runtime: + logger.warning( + "Profiling can only be enabled when using the C++ runtime" + ) + else: + if not os.path.exists(settings.engine_vis_dir): + os.makedirs(settings.engine_vis_dir) + trt_module.enable_profiling( + profiling_results_dir=settings.engine_vis_dir, + profile_format="trex", + ) + # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index aafd1072f4..6bea10171a 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -15,6 +15,7 @@ DLA_SRAM_SIZE = 1048576 ENGINE_CAPABILITY = EngineCapability.STANDARD WORKSPACE_SIZE = 0 +ENGINE_VIS_DIR = None MIN_BLOCK_SIZE = 5 PASS_THROUGH_BUILD_FAILURES = False MAX_AUX_STREAMS = None diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 97c02f34fb..da56a780d4 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -18,6 +18,7 @@ ENABLE_WEIGHT_STREAMING, ENABLED_PRECISIONS, ENGINE_CAPABILITY, + ENGINE_VIS_DIR, HARDWARE_COMPATIBLE, IMMUTABLE_WEIGHTS, L2_LIMIT_FOR_TILING, diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index c3fe925eee..aa964998be 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -334,7 +334,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: return tuple(outputs) - def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None: + def enable_profiling( + self, profiling_results_dir: Optional[str] = None, profile_format: str = "trex" + ) -> None: """Enable the profiler to collect latency information about the execution of the engine Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives @@ -347,7 
+349,9 @@ def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None: if profiling_results_dir is not None: self.engine.profile_path_prefix = profiling_results_dir + assert profile_format in ["trex", "perfetto"] self.engine.enable_profiling() + self.engine.set_profile_format(profile_format) def disable_profiling(self) -> None: """Disable the profiler""" diff --git a/tools/debug/engine_visualization/README.md b/tools/debug/engine_visualization/README.md new file mode 100644 index 0000000000..cacc5543ab --- /dev/null +++ b/tools/debug/engine_visualization/README.md @@ -0,0 +1,11 @@ +## Introduction +We use the TRT Engine Explorer (TREX) to visualize the engien graph structure. TREX is a diagnostic and profiling tool for TensorRT engine files. It allows you to inspect, benchmark, and debug TensorRT engines with ease. + +## Installation +```bash +git clone https://github.com/NVIDIA/TensorRT.git +cd TensorRT/tools/experimental/trt-engine-explorer +python3 -m pip install -e .[notebook] +sudo apt --yes install graphviz +``` + diff --git a/tools/debug/engine_visualization/draw_engine_graph.py b/tools/debug/engine_visualization/draw_engine_graph.py new file mode 100644 index 0000000000..4de5eafd0d --- /dev/null +++ b/tools/debug/engine_visualization/draw_engine_graph.py @@ -0,0 +1,44 @@ +import argparse +import os +import re +import shutil +import subprocess +import warnings +from typing import Tuple + +import networkx as nx +import trex +import trex.engine_plan +import trex.graphing + + +def draw_engine(dir_path: str): + try: + import trex + except ImportError: + print("trex is required but it is not installed.\n") + print("Check README.md for installation instructions.") + exit() + + engine_json_fname = os.path.join( + dir_path, "_run_on_acc_0_engine_layer_information.json" + ) + profiling_json_fname = os.path.join( + dir_path, "_run_on_acc_0_engine_engine_exectuion_profile.trace" + ) + + graphviz_is_installed = shutil.which("dot") is not None + if 
not graphviz_is_installed: + print("graphviz is required but it is not installed.\n") + print("To install on Ubuntu:") + print("sudo apt --yes install graphviz") + exit() + + plan = trex.engine_plan.EnginePlan( + engine_json_fname, profiling_file=profiling_json_fname + ) + layer_node_formatter = trex.graphing.layer_type_formatter + graph = trex.graphing.to_dot(plan, layer_node_formatter) + output_format = "png" # svg or jpg + + trex.graphing.render_dot(graph, engine_json_fname, output_format) diff --git a/tools/debug/engine_visualization/llama_hlo.py b/tools/debug/engine_visualization/llama_hlo.py new file mode 100644 index 0000000000..2d11ad0753 --- /dev/null +++ b/tools/debug/engine_visualization/llama_hlo.py @@ -0,0 +1,28 @@ +import numpy as np +import torch +import torch_tensorrt as torch_tensorrt +import torchvision.models as models + +inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] +model = models.resnet18(pretrained=False).eval().to("cuda") +exp_program = torch.export.export(model, tuple(inputs)) +enabled_precisions = {torch.float} +debug = False +workspace_size = 20 << 30 +min_block_size = 0 +use_python_runtime = False +torch_executed_ops = {} +trt_gm = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + enabled_precisions=enabled_precisions, + truncate_double=True, + debug=True, + use_python_runtime=False, + engine_vis_dir="/home/profile", +) +trt_output = trt_gm(*inputs) + +from draw_engine_graph import draw_engine + +draw_engine("/home/profile") From 8de3947db0552dd15a09ad2674729ae1bbb85d35 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 28 May 2025 21:08:39 +0000 Subject: [PATCH 07/12] Fixed the comments, changed dump function --- core/runtime/TRTEngine.cpp | 6 ++---- core/runtime/TRTEngine.h | 1 - core/runtime/TRTEngineProfiler.cpp | 10 +++++++--- core/runtime/TRTEngineProfiler.h | 5 +++-- core/runtime/execute_engine.cpp | 10 ++-------- .../dynamo/runtime/_TorchTensorRTModule.py | 4 +++- .../{llama_hlo.py => 
draw_engine_graph_example.py} | 0 7 files changed, 17 insertions(+), 19 deletions(-) rename tools/debug/engine_visualization/{llama_hlo.py => draw_engine_graph_example.py} (100%) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6fd067a20f..7bf7dd6b6d 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -283,14 +283,12 @@ void TRTEngine::enable_profiling() { void TRTEngine::set_profile_format(std::string format) { if (format == "trex") { - profile_format = TraceFormat::kTREX; + this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX); } else if (format == "perfetto") { - profile_format = TraceFormat::kPERFETTO; + this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO); } else { TORCHTRT_THROW_ERROR("Invalid profile format: " + format); } - - profile_format = profile_format; } std::string TRTEngine::get_engine_layer_info() { diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 23bada84cd..15d723ce4e 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -192,7 +192,6 @@ struct TRTEngine : torch::CustomClassHolder { #else bool profile_execution = false; #endif - TraceFormat profile_format = TraceFormat::kPERFETTO; std::string device_profile_path; std::string input_profile_path; std::string output_profile_path; diff --git a/core/runtime/TRTEngineProfiler.cpp b/core/runtime/TRTEngineProfiler.cpp index 5996a75e85..261ccc59c5 100644 --- a/core/runtime/TRTEngineProfiler.cpp +++ b/core/runtime/TRTEngineProfiler.cpp @@ -32,7 +32,11 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector< } } -void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format) { +void TRTEngineProfiler::set_profile_format(TraceFormat format) { + this->profile_format = format; +} + +void dump_trace(const std::string& path, const TRTEngineProfiler& value) { std::stringstream out; out << "[" << std::endl; double ts = 0.0; @@ -48,17 +52,17 
@@ void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFo out << " {" << std::endl; out << " \"name\": \"" << layer_name << "\"," << std::endl; - if (format == kPERFETTO) { + if (value.profile_format == TraceFormat::kPERFETTO) { out << " \"ph\": \"X\"," << std::endl; out << " \"ts\": " << running_time * 1000 << "," << std::endl; out << " \"dur\": " << elem.time * 1000 << "," << std::endl; out << " \"tid\": 1," << std::endl; out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl; + out << " \"args\": {}" << std::endl; } else { // kTREX out << " \"timeMs\": " << elem.time << "," << std::endl; out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl; out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl; - out << " \"args\": {}" << std::endl; } out << " }," << std::endl; running_time += elem.time; diff --git a/core/runtime/TRTEngineProfiler.h b/core/runtime/TRTEngineProfiler.h index 682fa3889d..0ffa0705d1 100644 --- a/core/runtime/TRTEngineProfiler.h +++ b/core/runtime/TRTEngineProfiler.h @@ -19,18 +19,19 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler { float time{0}; int count{0}; }; - + void set_profile_format(TraceFormat format); virtual void reportLayerTime(const char* layerName, float ms) noexcept; TRTEngineProfiler( const std::string& name, const std::vector& srcProfilers = std::vector()); friend std::ostream& operator<<(std::ostream& out, const TRTEngineProfiler& value); - friend void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format); + friend void dump_trace(const std::string& path, const TRTEngineProfiler& value); private: std::string name; std::vector layer_names; std::map profile; + TraceFormat profile_format = TraceFormat::kPERFETTO; }; } // namespace runtime diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index e24ea8df29..64b111750f 100644 --- a/core/runtime/execute_engine.cpp +++ 
b/core/runtime/execute_engine.cpp @@ -339,10 +339,7 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->profile_execution) { LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - dump_trace( - compiled_engine->trt_engine_profile_path, - *compiled_engine->trt_engine_profiler, - compiled_engine->profile_format); + dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); compiled_engine->dump_engine_layer_info(); } @@ -443,10 +440,7 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->profile_execution) { LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - dump_trace( - compiled_engine->trt_engine_profile_path, - *compiled_engine->trt_engine_profiler, - compiled_engine->profile_format); + dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); compiled_engine->dump_engine_layer_info(); } diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index aa964998be..95f1581881 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -335,7 +335,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: return tuple(outputs) def enable_profiling( - self, profiling_results_dir: Optional[str] = None, profile_format: str = "trex" + self, + profiling_results_dir: Optional[str] = None, + profile_format: str = "perfetto", ) -> None: """Enable the profiler to collect latency information about the execution of the engine diff --git a/tools/debug/engine_visualization/llama_hlo.py b/tools/debug/engine_visualization/draw_engine_graph_example.py similarity index 100% rename from tools/debug/engine_visualization/llama_hlo.py rename to tools/debug/engine_visualization/draw_engine_graph_example.py From 1f2b35e3834f3b090a110b332e984e82308c68cd Mon Sep 17 00:00:00 2001 
From: cehongwang Date: Fri, 30 May 2025 22:49:02 +0000 Subject: [PATCH 08/12] Added torchtrt.dynamo.debugger. Cleaning settings.debug --- py/torch_tensorrt/dynamo/__init__.py | 1 + py/torch_tensorrt/dynamo/_compiler.py | 25 ++- py/torch_tensorrt/dynamo/_debugger.py | 177 ++++++++++++++++++ py/torch_tensorrt/dynamo/_defaults.py | 1 - py/torch_tensorrt/dynamo/_refit.py | 7 - py/torch_tensorrt/dynamo/_settings.py | 5 +- .../dynamo/conversion/_TRTInterpreter.py | 1 - .../dynamo/conversion/_conversion.py | 2 - .../partitioning/_adjacency_partitioner.py | 9 +- .../partitioning/_global_partitioner.py | 9 +- .../dynamo/partitioning/common.py | 5 +- tools/debug/engine_visualization/README.md | 4 +- .../draw_engine_graph_example.py | 36 ++-- 13 files changed, 228 insertions(+), 54 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/_debugger.py diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 6fabdad633..675dd0cd53 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -14,6 +14,7 @@ load_cross_compiled_exported_program, save_cross_compiled_exported_program, ) + from ._debugger import Debugger from ._exporter import export from ._refit import refit_module_weights from ._settings import CompilationSettings diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8541734d9..b46c2327aa 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -520,7 +520,13 @@ def compile( """ if debug: - set_log_level(logger.parent, logging.DEBUG) + warnings.warn( + "The 'debug' argument is deprecated and will be removed in a future release. 
" + "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.", + DeprecationWarning, + stacklevel=2, + ) + if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: raise ValueError( @@ -642,7 +648,6 @@ def compile( "enabled_precisions": ( enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS ), - "debug": debug, "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, @@ -745,7 +750,7 @@ def compile_module( # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( - gm, settings.debug, settings.torch_executed_ops + gm, settings.torch_executed_ops ) dryrun_tracker.total_ops_in_graph = total_ops @@ -797,7 +802,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: logger.info("Partitioning the graph via the fast partitioner") partitioned_module, supported_ops = partitioning.fast_partition( gm, - verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, @@ -818,7 +822,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: logger.info("Partitioning the graph via the global partitioner") partitioned_module, supported_ops = partitioning.global_partition( gm, - verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, @@ -925,17 +928,21 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) trt_modules[name] = trt_module + from torch_tensorrt.dynamo._debugger import ( + DEBUG_FILE_DIR, + SAVE_ENGINE_PROFILE, + ) - if settings.debug and settings.engine_vis_dir: + if SAVE_ENGINE_PROFILE: if settings.use_python_runtime: logger.warning( "Profiling can only be enabled when using the C++ runtime" ) else: - 
if not os.path.exists(settings.engine_vis_dir): - os.makedirs(settings.engine_vis_dir) + path = os.path.join(DEBUG_FILE_DIR, "engine_visualization") + os.makedirs(path, exist_ok=True) trt_module.enable_profiling( - profiling_results_dir=settings.engine_vis_dir, + profiling_results_dir=path, profile_format="trex", ) diff --git a/py/torch_tensorrt/dynamo/_debugger.py b/py/torch_tensorrt/dynamo/_debugger.py new file mode 100644 index 0000000000..782ac696f3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_debugger.py @@ -0,0 +1,177 @@ +import logging +import os +import tempfile +from logging.config import dictConfig +from typing import Any, List, Optional + +import torch +from torch_tensorrt.dynamo.lowering import ( + ATEN_POST_LOWERING_PASSES, + ATEN_PRE_LOWERING_PASSES, +) + +_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]") +GRAPH_LEVEL = 5 +logging.addLevelName(GRAPH_LEVEL, "GRAPHS") + +# Debugger States +DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name +SAVE_ENGINE_PROFILE = False + + +class Debugger: + def __init__( + self, + level: str, + capture_fx_graph_before: Optional[List[str]] = None, + capture_fx_graph_after: Optional[List[str]] = None, + save_engine_profile: bool = False, + logging_dir: Optional[str] = None, + ): + + if level != "graphs" and (capture_fx_graph_after or save_engine_profile): + _LOGGER.warning( + "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'" + ) + + if level == "debug": + self.level = logging.DEBUG + elif level == "info": + self.level = logging.INFO + elif level == "warning": + self.level = logging.WARNING + elif level == "error": + self.level = logging.ERROR + elif level == "internal_errors": + self.level = logging.CRITICAL + elif level == "graphs": + self.level = GRAPH_LEVEL + + else: + raise ValueError( + f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs" + ) + + self.capture_fx_graph_before = capture_fx_graph_before + 
self.capture_fx_graph_after = capture_fx_graph_after + global SAVE_ENGINE_PROFILE + SAVE_ENGINE_PROFILE = save_engine_profile + + if logging_dir is not None: + global DEBUG_FILE_DIR + DEBUG_FILE_DIR = logging_dir + os.makedirs(DEBUG_FILE_DIR, exist_ok=True) + + def __enter__(self) -> None: + self.original_lvl = _LOGGER.getEffectiveLevel() + self.rt_level = torch.ops.tensorrt.get_logging_level() + dictConfig(self.get_config()) + + if self.level == GRAPH_LEVEL: + self.old_pre_passes, self.old_post_passes = ( + ATEN_PRE_LOWERING_PASSES.passes, + ATEN_POST_LOWERING_PASSES.passes, + ) + pre_pass_names = [p.__name__ for p in self.old_pre_passes] + post_pass_names = [p.__name__ for p in self.old_post_passes] + path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization") + if self.capture_fx_graph_before is not None: + pre_vis_passes = [ + p for p in self.capture_fx_graph_before if p in pre_pass_names + ] + post_vis_passes = [ + p for p in self.capture_fx_graph_before if p in post_pass_names + ] + ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path) + ATEN_POST_LOWERING_PASSES.insert_debug_pass_before( + post_vis_passes, path + ) + if self.capture_fx_graph_after is not None: + pre_vis_passes = [ + p for p in self.capture_fx_graph_after if p in pre_pass_names + ] + post_vis_passes = [ + p for p in self.capture_fx_graph_after if p in post_pass_names + ] + ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path) + ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path) + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + + dictConfig(self.get_default_config()) + torch.ops.tensorrt.set_logging_level(self.rt_level) + if self.level == GRAPH_LEVEL and self.capture_fx_graph_after: + ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = ( + self.old_pre_passes, + self.old_post_passes, + ) + + def get_config(self) -> dict[str, Any]: + config = { + "version": 1, + 
"disable_existing_loggers": False, + "formatters": { + "brief": { + "format": "%(asctime)s - %(levelname)s - %(message)s", + "datefmt": "%H:%M:%S", + }, + "standard": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + }, + "handlers": { + "file": { + "level": self.level, + "class": "logging.FileHandler", + "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log", + "formatter": "standard", + }, + "console": { + "level": self.level, + "class": "logging.StreamHandler", + "formatter": "brief", + }, + }, + "loggers": { + "": { # root logger + "handlers": ["file", "console"], + "level": self.level, + "propagate": True, + }, + }, + "force": True, + } + return config + + def get_default_config(self) -> dict[str, Any]: + config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "brief": { + "format": "%(asctime)s - %(levelname)s - %(message)s", + "datefmt": "%H:%M:%S", + }, + "standard": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + }, + "handlers": { + "console": { + "level": self.original_lvl, + "class": "logging.StreamHandler", + "formatter": "brief", + }, + }, + "loggers": { + "": { # root logger + "handlers": ["console"], + "level": self.original_lvl, + "propagate": True, + }, + }, + "force": True, + } + return config diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 6bea10171a..aafd1072f4 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -15,7 +15,6 @@ DLA_SRAM_SIZE = 1048576 ENGINE_CAPABILITY = EngineCapability.STANDARD WORKSPACE_SIZE = 0 -ENGINE_VIS_DIR = None MIN_BLOCK_SIZE = 5 PASS_THROUGH_BUILD_FAILURES = False MAX_AUX_STREAMS = None diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 7be7e0f16c..7e559e2649 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ 
b/py/torch_tensorrt/dynamo/_refit.py @@ -39,7 +39,6 @@ check_module_output, get_model_device, get_torch_inputs, - set_log_level, to_torch_device, to_torch_tensorrt_device, ) @@ -72,7 +71,6 @@ def construct_refit_mapping( interpreter = TRTInterpreter( module, inputs, - logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, compilation_settings=settings, ) @@ -266,9 +264,6 @@ def refit_module_weights( not settings.immutable_weights ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False." - if settings.debug: - set_log_level(logger.parent, logging.DEBUG) - device = to_torch_tensorrt_device(settings.device) if arg_inputs: if not isinstance(arg_inputs, collections.abc.Sequence): @@ -304,7 +299,6 @@ def refit_module_weights( try: new_partitioned_module, supported_ops = partitioning.fast_partition( new_gm, - verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) @@ -320,7 +314,6 @@ def refit_module_weights( if not settings.use_fast_partitioner: new_partitioned_module, supported_ops = partitioning.global_partition( new_gm, - verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index da56a780d4..8a583b63e3 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Collection, Optional, Set, Tuple, Union @@ -7,7 +8,6 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, - DEBUG, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -18,7 +18,6 @@ ENABLE_WEIGHT_STREAMING, ENABLED_PRECISIONS, ENGINE_CAPABILITY, - ENGINE_VIS_DIR, HARDWARE_COMPATIBLE, IMMUTABLE_WEIGHTS, L2_LIMIT_FOR_TILING, @@ -102,7 +101,7 @@ 
class CompilationSettings: """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) - debug: bool = DEBUG + debug: bool = logging.root.manager.root.level <= logging.DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE torch_executed_ops: Collection[Target] = field(default_factory=set) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 39a1ed957d..930511666a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -75,7 +75,6 @@ def __init__( self, module: torch.fx.GraphModule, input_specs: Sequence[Input], - logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, output_dtypes: Optional[Sequence[dtype]] = None, compilation_settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index adb7039e7e..35b6c26617 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -3,7 +3,6 @@ import logging from typing import Any, List, Optional, Sequence -import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -60,7 +59,6 @@ def interpret_module_to_result( interpreter = TRTInterpreter( module, inputs, - logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, compilation_settings=settings, engine_cache=engine_cache, diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 429de3ffbb..2cb7fe43f5 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ 
b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -13,14 +13,15 @@ ) from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet from torch_tensorrt.dynamo._defaults import ( - DEBUG, MIN_BLOCK_SIZE, REQUIRE_FULL_COMPILATION, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + ConverterRegistry, +) logger = logging.getLogger(__name__) @@ -250,7 +251,6 @@ def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: def partition( gm: torch.fx.GraphModule, - verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, @@ -286,7 +286,6 @@ def partition( partitioned_graph = partitioner.partition_graph() - if verbose: - supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) + supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) return partitioned_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index bdca0e1e1d..3279db00cf 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -7,14 +7,15 @@ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupport, SupportDict from torch_tensorrt.dynamo._defaults import ( - DEBUG, MIN_BLOCK_SIZE, REQUIRE_FULL_COMPILATION, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + 
ConverterRegistry, +) logger = logging.getLogger(__name__) @@ -200,7 +201,6 @@ def print_support_overview( def partition( gm: torch.fx.GraphModule, - verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, @@ -229,7 +229,6 @@ def partition( # Then, fuse partitions and display overview of supported/unsupported operators partitions = partitioner.propose_partitions() fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_") - if verbose: - supported_ops.print_support_overview(len(partitions)) + supported_ops.print_support_overview(len(partitions)) return fused_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 685ec6ebef..e499e988a9 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -5,7 +5,6 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import DEBUG from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info logger = logging.getLogger(__name__) @@ -169,7 +168,6 @@ def get_submodule_io( def get_graph_converter_support( graph_module: torch.fx.GraphModule, - verbose: bool = DEBUG, torch_executed_ops: Optional[Set[str]] = None, ) -> Tuple[int, int]: """Helper function to get converter support overview pre-partitioning @@ -199,7 +197,6 @@ def get_graph_converter_support( number_of_supported_nodes += 1 # Print node support overview prior to partitioning - if verbose: - op_support.print_support_overview(print_node_support=True) + op_support.print_support_overview(print_node_support=True) return number_of_supported_nodes, total_functional_nodes diff --git a/tools/debug/engine_visualization/README.md 
b/tools/debug/engine_visualization/README.md index cacc5543ab..40147cb17c 100644 --- a/tools/debug/engine_visualization/README.md +++ b/tools/debug/engine_visualization/README.md @@ -3,9 +3,7 @@ We use the TRT Engine Explorer (TREX) to visualize the engien graph structure. T ## Installation ```bash -git clone https://github.com/NVIDIA/TensorRT.git -cd TensorRT/tools/experimental/trt-engine-explorer -python3 -m pip install -e .[notebook] +pip install git+https://github.com/NVIDIA/TensorRT.git#subdirectory=tools/experimental/trt-engine-explorer sudo apt --yes install graphviz ``` diff --git a/tools/debug/engine_visualization/draw_engine_graph_example.py b/tools/debug/engine_visualization/draw_engine_graph_example.py index 2d11ad0753..490cb060b5 100644 --- a/tools/debug/engine_visualization/draw_engine_graph_example.py +++ b/tools/debug/engine_visualization/draw_engine_graph_example.py @@ -1,3 +1,6 @@ +import logging +import os + import numpy as np import torch import torch_tensorrt as torch_tensorrt @@ -7,22 +10,27 @@ model = models.resnet18(pretrained=False).eval().to("cuda") exp_program = torch.export.export(model, tuple(inputs)) enabled_precisions = {torch.float} -debug = False workspace_size = 20 << 30 -min_block_size = 0 +# min_block_size = 0 use_python_runtime = False torch_executed_ops = {} -trt_gm = torch_tensorrt.dynamo.compile( - exp_program, - inputs=inputs, - enabled_precisions=enabled_precisions, - truncate_double=True, - debug=True, - use_python_runtime=False, - engine_vis_dir="/home/profile", -) -trt_output = trt_gm(*inputs) +logging_dir = "/home/profile" +with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=logging_dir, + capture_fx_graph_after=["constant_fold"], + save_engine_profile=True, +): + trt_gm = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + enabled_precisions=enabled_precisions, + truncate_double=True, + use_python_runtime=False, + ) + trt_output = trt_gm(*inputs) -from draw_engine_graph import draw_engine + from 
draw_engine_graph import draw_engine -draw_engine("/home/profile") + draw_engine(os.path.join(logging_dir, "engine_visualization")) +print() From ec390e55b1e364f167e5c9e3989faa975eb5f08a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 3 Jun 2025 04:58:01 +0000 Subject: [PATCH 09/12] Revert to debug flag --- .../dynamo/{_debugger.py => Debugger.py} | 58 +++++++++---------- py/torch_tensorrt/dynamo/_compiler.py | 27 +-------- py/torch_tensorrt/dynamo/_settings.py | 4 +- 3 files changed, 29 insertions(+), 60 deletions(-) rename py/torch_tensorrt/dynamo/{_debugger.py => Debugger.py} (78%) diff --git a/py/torch_tensorrt/dynamo/_debugger.py b/py/torch_tensorrt/dynamo/Debugger.py similarity index 78% rename from py/torch_tensorrt/dynamo/_debugger.py rename to py/torch_tensorrt/dynamo/Debugger.py index 782ac696f3..af79bc69ab 100644 --- a/py/torch_tensorrt/dynamo/_debugger.py +++ b/py/torch_tensorrt/dynamo/Debugger.py @@ -14,67 +14,60 @@ GRAPH_LEVEL = 5 logging.addLevelName(GRAPH_LEVEL, "GRAPHS") -# Debugger States -DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name -SAVE_ENGINE_PROFILE = False - class Debugger: def __init__( self, - level: str, + log_level: str, capture_fx_graph_before: Optional[List[str]] = None, capture_fx_graph_after: Optional[List[str]] = None, save_engine_profile: bool = False, logging_dir: Optional[str] = None, ): - - if level != "graphs" and (capture_fx_graph_after or save_engine_profile): + self.debug_file_dir = tempfile.TemporaryDirectory().name + if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile): _LOGGER.warning( "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'" ) - if level == "debug": - self.level = logging.DEBUG - elif level == "info": - self.level = logging.INFO - elif level == "warning": - self.level = logging.WARNING - elif level == "error": - self.level = logging.ERROR - elif level == "internal_errors": - self.level = logging.CRITICAL - elif level == "graphs": - self.level = 
GRAPH_LEVEL + if log_level == "debug": + self.log_level = logging.DEBUG + elif log_level == "info": + self.log_level = logging.INFO + elif log_level == "warning": + self.log_level = logging.WARNING + elif log_level == "error": + self.log_level = logging.ERROR + elif log_level == "internal_errors": + self.log_level = logging.CRITICAL + elif log_level == "graphs": + self.log_level = GRAPH_LEVEL else: raise ValueError( - f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs" + f"Invalid level: {log_level}, allowed levels are: debug, info, warning, error, internal_errors, graphs" ) self.capture_fx_graph_before = capture_fx_graph_before self.capture_fx_graph_after = capture_fx_graph_after - global SAVE_ENGINE_PROFILE - SAVE_ENGINE_PROFILE = save_engine_profile if logging_dir is not None: - global DEBUG_FILE_DIR - DEBUG_FILE_DIR = logging_dir - os.makedirs(DEBUG_FILE_DIR, exist_ok=True) + self.debug_file_dir = logging_dir + os.makedirs(self.debug_file_dir, exist_ok=True) def __enter__(self) -> None: self.original_lvl = _LOGGER.getEffectiveLevel() self.rt_level = torch.ops.tensorrt.get_logging_level() dictConfig(self.get_config()) - if self.level == GRAPH_LEVEL: + if self.log_level == GRAPH_LEVEL: self.old_pre_passes, self.old_post_passes = ( ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes, ) pre_pass_names = [p.__name__ for p in self.old_pre_passes] post_pass_names = [p.__name__ for p in self.old_post_passes] - path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization") + path = os.path.join(self.debug_file_dir, "lowering_passes_visualization") if self.capture_fx_graph_before is not None: pre_vis_passes = [ p for p in self.capture_fx_graph_before if p in pre_pass_names @@ -100,11 +93,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: dictConfig(self.get_default_config()) torch.ops.tensorrt.set_logging_level(self.rt_level) - if self.level == GRAPH_LEVEL and 
self.capture_fx_graph_after: + if self.log_level == GRAPH_LEVEL and self.capture_fx_graph_after: ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = ( self.old_pre_passes, self.old_post_passes, ) + self.debug_file_dir = tempfile.TemporaryDirectory().name def get_config(self) -> dict[str, Any]: config = { @@ -122,13 +116,13 @@ def get_config(self) -> dict[str, Any]: }, "handlers": { "file": { - "level": self.level, + "level": self.log_level, "class": "logging.FileHandler", - "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log", + "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log", "formatter": "standard", }, "console": { - "level": self.level, + "level": self.log_level, "class": "logging.StreamHandler", "formatter": "brief", }, @@ -136,7 +130,7 @@ def get_config(self) -> dict[str, Any]: "loggers": { "": { # root logger "handlers": ["file", "console"], - "level": self.level, + "level": self.log_level, "propagate": True, }, }, diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index b46c2327aa..a849c6501b 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,7 +2,6 @@ import collections.abc import logging -import os import platform import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union @@ -519,14 +518,6 @@ def compile( torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT """ - if debug: - warnings.warn( - "The 'debug' argument is deprecated and will be removed in a future release. 
" - "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.", - DeprecationWarning, - stacklevel=2, - ) - if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: raise ValueError( @@ -648,6 +639,7 @@ def compile( "enabled_precisions": ( enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS ), + "debug": debug, "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, @@ -928,23 +920,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) trt_modules[name] = trt_module - from torch_tensorrt.dynamo._debugger import ( - DEBUG_FILE_DIR, - SAVE_ENGINE_PROFILE, - ) - - if SAVE_ENGINE_PROFILE: - if settings.use_python_runtime: - logger.warning( - "Profiling can only be enabled when using the C++ runtime" - ) - else: - path = os.path.join(DEBUG_FILE_DIR, "engine_visualization") - os.makedirs(path, exist_ok=True) - trt_module.enable_profiling( - profiling_results_dir=path, - profile_format="trex", - ) # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 8a583b63e3..97c02f34fb 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,4 +1,3 @@ -import logging from dataclasses import dataclass, field from typing import Collection, Optional, Set, Tuple, Union @@ -8,6 +7,7 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + DEBUG, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -101,7 +101,7 @@ class CompilationSettings: """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) - debug: bool = logging.root.manager.root.level <= logging.DEBUG + debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE 
torch_executed_ops: Collection[Target] = field(default_factory=set) From d6aa8a4dd080c44a8e468b9baa22cb5fcb000231 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 3 Jun 2025 22:01:05 +0000 Subject: [PATCH 10/12] Fixed the comments --- .../dynamo/{Debugger.py => _Debugger.py} | 12 ++++-------- py/torch_tensorrt/dynamo/__init__.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) rename py/torch_tensorrt/dynamo/{Debugger.py => _Debugger.py} (93%) diff --git a/py/torch_tensorrt/dynamo/Debugger.py b/py/torch_tensorrt/dynamo/_Debugger.py similarity index 93% rename from py/torch_tensorrt/dynamo/Debugger.py rename to py/torch_tensorrt/dynamo/_Debugger.py index af79bc69ab..2b92e1fa51 100644 --- a/py/torch_tensorrt/dynamo/Debugger.py +++ b/py/torch_tensorrt/dynamo/_Debugger.py @@ -25,10 +25,6 @@ def __init__( logging_dir: Optional[str] = None, ): self.debug_file_dir = tempfile.TemporaryDirectory().name - if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile): - _LOGGER.warning( - "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'" - ) if log_level == "debug": self.log_level = logging.DEBUG @@ -60,7 +56,7 @@ def __enter__(self) -> None: self.rt_level = torch.ops.tensorrt.get_logging_level() dictConfig(self.get_config()) - if self.log_level == GRAPH_LEVEL: + if self.capture_fx_graph_before or self.capture_fx_graph_after: self.old_pre_passes, self.old_post_passes = ( ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes, @@ -93,14 +89,14 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: dictConfig(self.get_default_config()) torch.ops.tensorrt.set_logging_level(self.rt_level) - if self.log_level == GRAPH_LEVEL and self.capture_fx_graph_after: + if self.capture_fx_graph_before or self.capture_fx_graph_after: ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = ( self.old_pre_passes, self.old_post_passes, ) self.debug_file_dir = tempfile.TemporaryDirectory().name - 
def get_config(self) -> dict[str, Any]: + def get_customized_logging_config(self) -> dict[str, Any]: config = { "version": 1, "disable_existing_loggers": False, @@ -138,7 +134,7 @@ def get_config(self) -> dict[str, Any]: } return config - def get_default_config(self) -> dict[str, Any]: + def get_default_logging_config(self) -> dict[str, Any]: config = { "version": 1, "disable_existing_loggers": False, diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 675dd0cd53..15a17a4f02 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -14,7 +14,7 @@ load_cross_compiled_exported_program, save_cross_compiled_exported_program, ) - from ._debugger import Debugger + from ._Debugger import Debugger from ._exporter import export from ._refit import refit_module_weights from ._settings import CompilationSettings From 5683644d1d0d2820e49415a01b3a5639aa9f49e3 Mon Sep 17 00:00:00 2001 From: Adrian Wang <123616592+cehongwang@users.noreply.github.com> Date: Thu, 5 Jun 2025 21:49:19 -0700 Subject: [PATCH 11/12] Changed the debug setting (#3551) --- py/torch_tensorrt/dynamo/__init__.py | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 62 ++++++++++++-- py/torch_tensorrt/dynamo/_defaults.py | 1 - py/torch_tensorrt/dynamo/_settings.py | 2 - py/torch_tensorrt/dynamo/_tracer.py | 8 +- .../dynamo/conversion/_TRTInterpreter.py | 12 ++- .../dynamo/{ => debug}/_Debugger.py | 81 ++++++++++++++++--- .../dynamo/debug/_DebuggerConfig.py | 12 +++ .../dynamo/debug/_supports_debugger.py | 17 ++++ .../runtime/_MutableTorchTensorRTModule.py | 11 ++- .../runtime/_PythonTorchTensorRTModule.py | 11 ++- 11 files changed, 183 insertions(+), 36 deletions(-) rename py/torch_tensorrt/dynamo/{ => debug}/_Debugger.py (65%) create mode 100644 py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py create mode 100644 py/torch_tensorrt/dynamo/debug/_supports_debugger.py diff --git a/py/torch_tensorrt/dynamo/__init__.py 
b/py/torch_tensorrt/dynamo/__init__.py index 15a17a4f02..607dca76bf 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -14,9 +14,9 @@ load_cross_compiled_exported_program, save_cross_compiled_exported_program, ) - from ._Debugger import Debugger from ._exporter import export from ._refit import refit_module_weights from ._settings import CompilationSettings from ._SourceIR import SourceIR from ._tracer import trace + from .debug._Debugger import Debugger diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a849c6501b..329f72785c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,6 +2,7 @@ import collections.abc import logging +import os import platform import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union @@ -31,6 +32,8 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) +from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig +from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -42,7 +45,6 @@ get_output_metadata, parse_graph_io, prepare_inputs, - set_log_level, to_torch_device, to_torch_tensorrt_device, ) @@ -64,7 +66,7 @@ def cross_compile_for_windows( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - debug: bool = _defaults.DEBUG, + debug: bool = False, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -186,7 +188,11 @@ def cross_compile_for_windows( ) if debug: - set_log_level(logger.parent, logging.DEBUG) + warnings.warn( + "`debug` is deprecated. 
Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.", + DeprecationWarning, + stacklevel=2, + ) if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: @@ -297,7 +303,6 @@ def cross_compile_for_windows( "enabled_precisions": ( enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS ), - "debug": debug, "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, @@ -399,7 +404,7 @@ def compile( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - debug: bool = _defaults.DEBUG, + debug: bool = False, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -518,6 +523,13 @@ def compile( torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT """ + if debug: + warnings.warn( + "`debug` is deprecated. 
Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality", + DeprecationWarning, + stacklevel=2, + ) + if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: raise ValueError( @@ -639,7 +651,6 @@ def compile( "enabled_precisions": ( enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS ), - "debug": debug, "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, @@ -713,12 +724,15 @@ def compile( return trt_gm +@fn_supports_debugger def compile_module( gm: torch.fx.GraphModule, sample_arg_inputs: Sequence[Input], sample_kwarg_inputs: Optional[dict[Any, Any]] = None, settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, + *, + _debugger_settings: Optional[DebuggerConfig] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -921,6 +935,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules[name] = trt_module + if _debugger_settings: + + if _debugger_settings.save_engine_profile: + if settings.use_python_runtime: + if _debugger_settings.profile_format == "trex": + logger.warning( + "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization." 
+ ) + trt_module.enable_profiling() + else: + path = os.path.join( + _debugger_settings.logging_dir, "engine_visualization" + ) + os.makedirs(path, exist_ok=True) + trt_module.enable_profiling( + profiling_results_dir=path, + profile_format=_debugger_settings.profile_format, + ) + + if _debugger_settings.save_layer_info: + with open( + os.path.join( + _debugger_settings.logging_dir, "engine_layer_info.json" + ), + "w", + ) as f: + f.write(trt_module.get_layer_info()) + # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) @@ -948,7 +990,7 @@ def convert_exported_program_to_serialized_trt_engine( enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, - debug: bool = _defaults.DEBUG, + debug: bool = False, assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, workspace_size: int = _defaults.WORKSPACE_SIZE, min_block_size: int = _defaults.MIN_BLOCK_SIZE, @@ -1051,7 +1093,11 @@ def convert_exported_program_to_serialized_trt_engine( bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ if debug: - set_log_level(logger.parent, logging.DEBUG) + warnings.warn( + "`debug` is deprecated. 
Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.", + DeprecationWarning, + stacklevel=2, + ) if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index aafd1072f4..226372a776 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -6,7 +6,6 @@ from torch_tensorrt._enums import EngineCapability, dtype ENABLED_PRECISIONS = {dtype.f32} -DEBUG = False DEVICE = None DISABLE_TF32 = False ASSUME_DYNAMIC_SHAPE_SUPPORT = False diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 97c02f34fb..7ac77cccae 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,7 +7,6 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, - DEBUG, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -101,7 +100,6 @@ class CompilationSettings: """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) - debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE torch_executed_ops: Collection[Target] = field(default_factory=set) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 78f7989777..5f4bdd0a8d 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -7,8 +7,8 @@ import torch from torch.export import Dim, export from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import DEBUG, default_device -from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device +from torch_tensorrt.dynamo._defaults import default_device +from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device logger = logging.getLogger(__name__) @@ -70,10 +70,6 @@ def trace( if 
kwarg_inputs is None: kwarg_inputs = {} - debug = kwargs.get("debug", DEBUG) - if debug: - set_log_level(logger.parent, logging.DEBUG) - device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 930511666a..4217d99232 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -46,6 +46,8 @@ to_torch, ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig +from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -70,6 +72,7 @@ class TRTInterpreterResult(NamedTuple): requires_output_allocator: bool +@cls_supports_debugger class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] def __init__( self, @@ -78,12 +81,14 @@ def __init__( output_dtypes: Optional[Sequence[dtype]] = None, compilation_settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, + *, + _debugger_settings: Optional[DebuggerConfig] = None, ): super().__init__(module) self.logger = TRT_LOGGER self.builder = trt.Builder(self.logger) - + self._debugger_settings = _debugger_settings flag = 0 if compilation_settings.use_explicit_typing: STRONGLY_TYPED = 1 << (int)( @@ -204,7 +209,7 @@ def _populate_trt_builder_config( ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() - if self.compilation_settings.debug: + if self._debugger_settings and self._debugger_settings.engine_builder_monitor: builder_config.progress_monitor = TRTBulderMonitor() if self.compilation_settings.workspace_size 
!= 0: @@ -215,7 +220,8 @@ def _populate_trt_builder_config( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.DETAILED - if self.compilation_settings.debug + if self._debugger_settings + and self._debugger_settings.save_engine_profile else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) diff --git a/py/torch_tensorrt/dynamo/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py similarity index 65% rename from py/torch_tensorrt/dynamo/_Debugger.py rename to py/torch_tensorrt/dynamo/debug/_Debugger.py index 2b92e1fa51..bb9dffbfc1 100644 --- a/py/torch_tensorrt/dynamo/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -1,10 +1,18 @@ +import contextlib +import functools import logging import os import tempfile from logging.config import dictConfig from typing import Any, List, Optional +from unittest import mock import torch +from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig +from torch_tensorrt.dynamo.debug._supports_debugger import ( + _DEBUG_ENABLED_CLS, + _DEBUG_ENABLED_FUNCS, +) from torch_tensorrt.dynamo.lowering import ( ATEN_POST_LOWERING_PASSES, ATEN_PRE_LOWERING_PASSES, @@ -18,13 +26,47 @@ class Debugger: def __init__( self, - log_level: str, + log_level: str = "debug", capture_fx_graph_before: Optional[List[str]] = None, capture_fx_graph_after: Optional[List[str]] = None, save_engine_profile: bool = False, - logging_dir: Optional[str] = None, + profile_format: str = "perfetto", + engine_builder_monitor: bool = True, + logging_dir: str = tempfile.gettempdir(), + save_layer_info: bool = False, ): - self.debug_file_dir = tempfile.TemporaryDirectory().name + """Initialize a debugger for TensorRT conversion. + + Args: + log_level (str): Logging level to use. Valid options are: + 'debug', 'info', 'warning', 'error', 'internal_errors', 'graphs'. + Defaults to 'debug'. 
+ capture_fx_graph_before (List[str], optional): List of pass names to visualize FX graph + before execution of a lowering pass. Defaults to None. + capture_fx_graph_after (List[str], optional): List of pass names to visualize FX graph + after execution of a lowering pass. Defaults to None. + save_engine_profile (bool): Whether to save TensorRT engine profiling information. + Defaults to False. + profile_format (str): Format for profiling data. Can be either 'perfetto' or 'trex'. + If you need to generate engine graph using the profiling files, set it to 'trex' . + Defaults to 'perfetto'. + engine_builder_monitor (bool): Whether to monitor TensorRT engine building process. + Defaults to True. + logging_dir (str): Directory to save debug logs and profiles. + Defaults to system temp directory. + save_layer_info (bool): Whether to save layer info. + Defaults to False. + """ + + os.makedirs(logging_dir, exist_ok=True) + self.cfg = DebuggerConfig( + log_level=log_level, + save_engine_profile=save_engine_profile, + engine_builder_monitor=engine_builder_monitor, + logging_dir=logging_dir, + profile_format=profile_format, + save_layer_info=save_layer_info, + ) if log_level == "debug": self.log_level = logging.DEBUG @@ -47,14 +89,10 @@ def __init__( self.capture_fx_graph_before = capture_fx_graph_before self.capture_fx_graph_after = capture_fx_graph_after - if logging_dir is not None: - self.debug_file_dir = logging_dir - os.makedirs(self.debug_file_dir, exist_ok=True) - def __enter__(self) -> None: self.original_lvl = _LOGGER.getEffectiveLevel() self.rt_level = torch.ops.tensorrt.get_logging_level() - dictConfig(self.get_config()) + dictConfig(self.get_customized_logging_config()) if self.capture_fx_graph_before or self.capture_fx_graph_after: self.old_pre_passes, self.old_post_passes = ( @@ -63,7 +101,7 @@ def __enter__(self) -> None: ) pre_pass_names = [p.__name__ for p in self.old_pre_passes] post_pass_names = [p.__name__ for p in self.old_post_passes] - path = 
os.path.join(self.debug_file_dir, "lowering_passes_visualization") + path = os.path.join(self.cfg.logging_dir, "lowering_passes_visualization") if self.capture_fx_graph_before is not None: pre_vis_passes = [ p for p in self.capture_fx_graph_before if p in pre_pass_names @@ -85,9 +123,25 @@ def __enter__(self) -> None: ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path) ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path) + self._context_stack = contextlib.ExitStack() + + for f in _DEBUG_ENABLED_FUNCS: + f.__kwdefaults__["_debugger_settings"] = self.cfg + + [ + self._context_stack.enter_context( + mock.patch.object( + c, + "__init__", + functools.partialmethod(c.__init__, _debugger_settings=self.cfg), + ) + ) + for c in _DEBUG_ENABLED_CLS + ] + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - dictConfig(self.get_default_config()) + dictConfig(self.get_default_logging_config()) torch.ops.tensorrt.set_logging_level(self.rt_level) if self.capture_fx_graph_before or self.capture_fx_graph_after: ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = ( @@ -96,6 +150,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ) self.debug_file_dir = tempfile.TemporaryDirectory().name + for f in _DEBUG_ENABLED_FUNCS: + f.__kwdefaults__["_debugger_settings"] = None + + self._context_stack.close() + def get_customized_logging_config(self) -> dict[str, Any]: config = { "version": 1, @@ -114,7 +173,7 @@ def get_customized_logging_config(self) -> dict[str, Any]: "file": { "level": self.log_level, "class": "logging.FileHandler", - "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log", + "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", }, "console": { diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py new file mode 100644 index 0000000000..3c409b0aa8 --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py @@ -0,0 +1,12 @@ +import tempfile +from dataclasses import dataclass + + +@dataclass +class DebuggerConfig: + log_level: str = "debug" + save_engine_profile: bool = False + engine_builder_monitor: bool = True + logging_dir: str = tempfile.gettempdir() + profile_format: str = "perfetto" + save_layer_info: bool = False diff --git a/py/torch_tensorrt/dynamo/debug/_supports_debugger.py b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py new file mode 100644 index 0000000000..2d9fd2a149 --- /dev/null +++ b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py @@ -0,0 +1,17 @@ +from typing import Any, Callable, Type, TypeVar + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + +_DEBUG_ENABLED_FUNCS = [] +_DEBUG_ENABLED_CLS = [] + + +def fn_supports_debugger(func: F) -> F: + _DEBUG_ENABLED_FUNCS.append(func) + return func + + +def cls_supports_debugger(cls: Type[T]) -> Type[T]: + _DEBUG_ENABLED_CLS.append(cls) + return cls diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index eaeb6a8c28..c6bd22f938 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -1,5 +1,6 @@ import inspect import logging +import warnings from copy import deepcopy from enum import Enum, auto from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union @@ -71,7 +72,7 @@ def __init__( ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, immutable_weights: bool = False, - debug: bool = _defaults.DEBUG, + debug: bool = False, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -109,7 +110,6 @@ def __init__( sparse_weights (bool): Enable sparsity for convolution and fully connected 
layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. - debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -156,6 +156,12 @@ def __init__( self.kwarg_inputs: dict[str, Any] = {} device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} + if debug: + warnings.warn( + "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.", + DeprecationWarning, + stacklevel=2, + ) assert ( not immutable_weights ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." 
@@ -165,7 +171,6 @@ def __init__( if enabled_precisions else _defaults.ENABLED_PRECISIONS ), - "debug": debug, "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6415ce11c3..8d1a31564d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -12,6 +12,8 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig +from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -111,6 +113,7 @@ def set_runtime_states( ) +@cls_supports_debugger class PythonTorchTensorRTModule(Module): # type: ignore[misc] """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. @@ -128,6 +131,7 @@ def __init__( settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, + _debugger_settings: Optional[DebuggerConfig] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. 
Uses TensorRT Python APIs to run the engine @@ -157,6 +161,7 @@ def __init__( """ self.context: Any + self._debugger_settings: Optional[DebuggerConfig] = _debugger_settings super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) @@ -193,7 +198,11 @@ def __init__( self.target_device_properties = torch.cuda.get_device_properties( self.target_device_id ) - self.profiling_enabled = settings.debug if settings.debug is not None else False + self.profiling_enabled = ( + _debugger_settings.save_engine_profile + if _debugger_settings is not None + else False + ) self.settings = settings self.engine = None self.weight_name_map = weight_name_map From fb5dc81a7260d243320578177ada1b6ad38aff0c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 6 Jun 2025 06:36:36 +0000 Subject: [PATCH 12/12] Fixed the comments --- core/runtime/TRTEngineProfiler.h | 2 - py/torch_tensorrt/dynamo/_compiler.py | 36 ++++------ py/torch_tensorrt/dynamo/_defaults.py | 1 + .../dynamo/conversion/_TRTInterpreter.py | 11 ++- py/torch_tensorrt/dynamo/debug/_Debugger.py | 68 ++++++------------- .../dynamo/debug/_DebuggerConfig.py | 5 +- .../dynamo/lowering/passes/pass_manager.py | 6 +- .../runtime/_PythonTorchTensorRTModule.py | 8 +-- tools/debug/engine_visualization/README.md | 4 +- .../draw_engine_graph_example.py | 17 ++--- 10 files changed, 61 insertions(+), 97 deletions(-) diff --git a/core/runtime/TRTEngineProfiler.h b/core/runtime/TRTEngineProfiler.h index 0ffa0705d1..6691f2e81d 100644 --- a/core/runtime/TRTEngineProfiler.h +++ b/core/runtime/TRTEngineProfiler.h @@ -12,8 +12,6 @@ namespace runtime { enum TraceFormat { kPERFETTO, kTREX }; -// Forward declare the function - struct TRTEngineProfiler : public nvinfer1::IProfiler { struct Record { float time{0}; diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 329f72785c..30c5b7b332 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ 
b/py/torch_tensorrt/dynamo/_compiler.py @@ -66,7 +66,6 @@ def cross_compile_for_windows( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - debug: bool = False, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -140,7 +139,6 @@ def cross_compile_for_windows( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -187,9 +185,9 @@ def cross_compile_for_windows( f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}" ) - if debug: + if kwargs.get("debug", False): warnings.warn( - "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.", + "`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) 
to wrap your compilation call to enable debugging functionality.", DeprecationWarning, stacklevel=2, ) @@ -404,7 +402,6 @@ def compile( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - debug: bool = False, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -480,7 +477,6 @@ def compile( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -523,9 +519,9 @@ def compile( torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT """ - if debug: + if kwargs.get("debug", False): warnings.warn( - "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality", + "`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) 
to wrap your compilation call to enable debugging functionality", DeprecationWarning, stacklevel=2, ) @@ -732,7 +728,7 @@ def compile_module( settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, *, - _debugger_settings: Optional[DebuggerConfig] = None, + _debugger_config: Optional[DebuggerConfig] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -935,29 +931,30 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules[name] = trt_module - if _debugger_settings: + if _debugger_config: - if _debugger_settings.save_engine_profile: + if _debugger_config.save_engine_profile: if settings.use_python_runtime: - if _debugger_settings.profile_format == "trex": + if _debugger_config.profile_format == "trex": logger.warning( "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization." ) trt_module.enable_profiling() else: path = os.path.join( - _debugger_settings.logging_dir, "engine_visualization" + _debugger_config.logging_dir, + "engine_visualization_profile", ) os.makedirs(path, exist_ok=True) trt_module.enable_profiling( profiling_results_dir=path, - profile_format=_debugger_settings.profile_format, + profile_format=_debugger_config.profile_format, ) - if _debugger_settings.save_layer_info: + if _debugger_config.save_layer_info: with open( os.path.join( - _debugger_settings.logging_dir, "engine_layer_info.json" + _debugger_config.logging_dir, "engine_layer_info.json" ), "w", ) as f: @@ -990,7 +987,6 @@ def convert_exported_program_to_serialized_trt_engine( enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, - debug: bool = False, assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, workspace_size: int = _defaults.WORKSPACE_SIZE, min_block_size: int = _defaults.MIN_BLOCK_SIZE, @@ -1052,7 +1048,6 @@ def 
convert_exported_program_to_serialized_trt_engine( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use - debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage @@ -1092,9 +1087,9 @@ def convert_exported_program_to_serialized_trt_engine( Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ - if debug: + if kwargs.get("debug", False): warnings.warn( - "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.", + "`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) 
to wrap your compilation call to enable debugging functionality.", DeprecationWarning, stacklevel=2, ) @@ -1181,7 +1176,6 @@ def convert_exported_program_to_serialized_trt_engine( compilation_options = { "assume_dynamic_shape_support": assume_dynamic_shape_support, "enabled_precisions": enabled_precisions, - "debug": debug, "workspace_size": workspace_size, "min_block_size": min_block_size, "torch_executed_ops": torch_executed_ops, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 226372a776..5c0c8a2bc8 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -49,6 +49,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +DEBUG_LOGGING_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt/debug_logs") def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 4217d99232..38fd0e6eda 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -45,9 +45,9 @@ get_trt_tensor, to_torch, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -82,13 +82,13 @@ def __init__( compilation_settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, *, - _debugger_settings: Optional[DebuggerConfig] = None, + _debugger_config: Optional[DebuggerConfig] = None, ): super().__init__(module) self.logger = TRT_LOGGER self.builder = trt.Builder(self.logger) - 
self._debugger_settings = _debugger_settings + self._debugger_config = _debugger_config flag = 0 if compilation_settings.use_explicit_typing: STRONGLY_TYPED = 1 << (int)( @@ -209,7 +209,7 @@ def _populate_trt_builder_config( ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() - if self._debugger_settings and self._debugger_settings.engine_builder_monitor: + if self._debugger_config and self._debugger_config.engine_builder_monitor: builder_config.progress_monitor = TRTBulderMonitor() if self.compilation_settings.workspace_size != 0: @@ -220,8 +220,7 @@ def _populate_trt_builder_config( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.DETAILED - if self._debugger_settings - and self._debugger_settings.save_engine_profile + if self._debugger_config and self._debugger_config.save_engine_profile else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index bb9dffbfc1..4c2a402b15 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -8,6 +8,7 @@ from unittest import mock import torch +from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import ( _DEBUG_ENABLED_CLS, @@ -32,7 +33,7 @@ def __init__( save_engine_profile: bool = False, profile_format: str = "perfetto", engine_builder_monitor: bool = True, - logging_dir: str = tempfile.gettempdir(), + logging_dir: str = DEBUG_LOGGING_DIR, save_layer_info: bool = False, ): """Initialize a debugger for TensorRT conversion. 
@@ -92,7 +93,7 @@ def __init__( def __enter__(self) -> None: self.original_lvl = _LOGGER.getEffectiveLevel() self.rt_level = torch.ops.tensorrt.get_logging_level() - dictConfig(self.get_customized_logging_config()) + dictConfig(self.get_logging_config(self.log_level)) if self.capture_fx_graph_before or self.capture_fx_graph_after: self.old_pre_passes, self.old_post_passes = ( @@ -126,14 +127,14 @@ def __enter__(self) -> None: self._context_stack = contextlib.ExitStack() for f in _DEBUG_ENABLED_FUNCS: - f.__kwdefaults__["_debugger_settings"] = self.cfg + f.__kwdefaults__["_debugger_config"] = self.cfg [ self._context_stack.enter_context( mock.patch.object( c, "__init__", - functools.partialmethod(c.__init__, _debugger_settings=self.cfg), + functools.partialmethod(c.__init__, _debugger_config=self.cfg), ) ) for c in _DEBUG_ENABLED_CLS @@ -141,7 +142,7 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - dictConfig(self.get_default_logging_config()) + dictConfig(self.get_logging_config(None)) torch.ops.tensorrt.set_logging_level(self.rt_level) if self.capture_fx_graph_before or self.capture_fx_graph_after: ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = ( @@ -151,50 +152,13 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: self.debug_file_dir = tempfile.TemporaryDirectory().name for f in _DEBUG_ENABLED_FUNCS: - f.__kwdefaults__["_debugger_settings"] = None + f.__kwdefaults__["_debugger_config"] = None self._context_stack.close() - def get_customized_logging_config(self) -> dict[str, Any]: - config = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "brief": { - "format": "%(asctime)s - %(levelname)s - %(message)s", - "datefmt": "%H:%M:%S", - }, - "standard": { - "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", - "datefmt": "%Y-%m-%d %H:%M:%S", - }, - }, - "handlers": { - "file": { - "level": self.log_level, - "class": 
"logging.FileHandler", - "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", - "formatter": "standard", - }, - "console": { - "level": self.log_level, - "class": "logging.StreamHandler", - "formatter": "brief", - }, - }, - "loggers": { - "": { # root logger - "handlers": ["file", "console"], - "level": self.log_level, - "propagate": True, - }, - }, - "force": True, - } - return config - - def get_default_logging_config(self) -> dict[str, Any]: - config = { + def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: + level = log_level if log_level is not None else self.original_lvl + config: dict[str, Any] = { "version": 1, "disable_existing_loggers": False, "formatters": { @@ -209,7 +173,7 @@ def get_default_logging_config(self) -> dict[str, Any]: }, "handlers": { "console": { - "level": self.original_lvl, + "level": level, "class": "logging.StreamHandler", "formatter": "brief", }, @@ -217,10 +181,18 @@ def get_default_logging_config(self) -> dict[str, Any]: "loggers": { "": { # root logger "handlers": ["console"], - "level": self.original_lvl, + "level": level, "propagate": True, }, }, "force": True, } + if log_level is not None: + config["handlers"]["file"] = { + "level": level, + "class": "logging.FileHandler", + "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", + "formatter": "standard", + } + config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py index 3c409b0aa8..27a5025e8b 100644 --- a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py +++ b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py @@ -1,12 +1,13 @@ -import tempfile from dataclasses import dataclass +from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR + @dataclass class DebuggerConfig: log_level: str = "debug" save_engine_profile: bool = False engine_builder_monitor: bool = True - logging_dir: str = 
tempfile.gettempdir() + logging_dir: str = DEBUG_LOGGING_DIR profile_format: str = "perfetto" save_layer_info: bool = False diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index c55897ff45..9c1f9e18d3 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,10 +1,10 @@ import os -import tempfile from typing import Any, Callable, List, Optional import torch from torch.fx import passes from torch.fx.passes.pass_manager import PassManager +from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR from torch_tensorrt.dynamo._settings import CompilationSettings @@ -70,7 +70,7 @@ def remove_pass_with_index(self, index: int) -> None: del self.passes[index] def insert_debug_pass_before( - self, passes: List[str], output_path_prefix: str = tempfile.gettempdir() + self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR ) -> None: """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass. @@ -96,7 +96,7 @@ def insert_debug_pass_before( self._validated = False def insert_debug_pass_after( - self, passes: List[str], output_path_prefix: str = tempfile.gettempdir() + self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR ) -> None: """Insert debug passes in the PassManager pass sequence after the execution of a particular pass. 
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 8d1a31564d..fc76b20141 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -131,7 +131,7 @@ def __init__( settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, - _debugger_settings: Optional[DebuggerConfig] = None, + _debugger_config: Optional[DebuggerConfig] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -161,7 +161,7 @@ def __init__( """ self.context: Any - self._debugger_settings: Optional[DebuggerConfig] = _debugger_settings + self._debugger_config: Optional[DebuggerConfig] = _debugger_config super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) @@ -199,8 +199,8 @@ def __init__( self.target_device_id ) self.profiling_enabled = ( - _debugger_settings.save_engine_profile - if _debugger_settings is not None + _debugger_config.save_engine_profile + if _debugger_config is not None else False ) self.settings = settings diff --git a/tools/debug/engine_visualization/README.md b/tools/debug/engine_visualization/README.md index 40147cb17c..90547b8ba9 100644 --- a/tools/debug/engine_visualization/README.md +++ b/tools/debug/engine_visualization/README.md @@ -1,5 +1,5 @@ ## Introduction -We use the TRT Engine Explorer (TREX) to visualize the engien graph structure. TREX is a diagnostic and profiling tool for TensorRT engine files. It allows you to inspect, benchmark, and debug TensorRT engines with ease. +We use the TRT Engine Explorer (TREX) to visualize the engine graph structure. 
TREX is a diagnostic and profiling tool for TensorRT engine files. It allows you to inspect, benchmark, and debug TensorRT engines with ease. ## Installation ```bash @@ -7,3 +7,5 @@ pip install git+https://github.com/NVIDIA/TensorRT.git#subdirectory=tools/experi sudo apt --yes install graphviz ``` +## Usage +The example usage can be found in `draw_engine_graph_example.py`. We use `torch_tensorrt.dynamo.Debugger` to first output the engine profile info that is required by TREX. Note that TREX profiling can only be produced when the compilation setting `use_python_runtime=False` is used. When it is saved to a folder, we call `draw_engine` on the same directory where the profile files are saved, which is in the subdirectory `engine_visualization_profile`. \ No newline at end of file diff --git a/tools/debug/engine_visualization/draw_engine_graph_example.py b/tools/debug/engine_visualization/draw_engine_graph_example.py index 490cb060b5..06aa3ae63b 100644 --- a/tools/debug/engine_visualization/draw_engine_graph_example.py +++ b/tools/debug/engine_visualization/draw_engine_graph_example.py @@ -5,32 +5,29 @@ import torch import torch_tensorrt as torch_tensorrt import torchvision.models as models +from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] model = models.resnet18(pretrained=False).eval().to("cuda") exp_program = torch.export.export(model, tuple(inputs)) -enabled_precisions = {torch.float} -workspace_size = 20 << 30 -# min_block_size = 0 -use_python_runtime = False -torch_executed_ops = {} -logging_dir = "/home/profile" + with torch_tensorrt.dynamo.Debugger( "graphs", - logging_dir=logging_dir, + logging_dir=DEBUG_LOGGING_DIR, capture_fx_graph_after=["constant_fold"], save_engine_profile=True, + profile_format="trex", ): trt_gm = torch_tensorrt.dynamo.compile( exp_program, inputs=inputs, - enabled_precisions=enabled_precisions, + enabled_precisions={torch.float}, truncate_double=True, use_python_runtime=False, +
min_block_size=1, ) trt_output = trt_gm(*inputs) from draw_engine_graph import draw_engine - draw_engine(os.path.join(logging_dir, "engine_visualization")) -print() + draw_engine(os.path.join(DEBUG_LOGGING_DIR, "engine_visualization_profile"))