
FX graph visualization #3528

Status: Open. Wants to merge 12 commits into base: main.
10 changes: 10 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -281,6 +281,16 @@ void TRTEngine::enable_profiling() {
exec_ctx->setProfiler(trt_engine_profiler.get());
}

void TRTEngine::set_profile_format(std::string format) {
if (format == "trex") {
this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
} else if (format == "perfetto") {
this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO);
} else {
TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
}
}

std::string TRTEngine::get_engine_layer_info() {
auto inspector = cuda_engine->createEngineInspector();
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
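
For reference, a rough sketch of how the new format switch is exercised from Python, given the `set_profile_format` binding added in `register_jit_hooks.cpp` below (the `engine` handle here is a hypothetical placeholder for the underlying TRTEngine custom class):

```python
# Only "trex" and "perfetto" are accepted; anything else throws.
engine.set_profile_format("perfetto")  # Chrome-trace events for Perfetto / chrome://tracing
engine.set_profile_format("trex")      # per-layer timing JSON for NVIDIA trex
engine.set_profile_format("json")      # raises: Invalid profile format: json
```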
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
@@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
void enable_profiling();
void set_profile_format(std::string profile_format);
void disable_profiling();
std::string get_engine_layer_info();

31 changes: 23 additions & 8 deletions core/runtime/TRTEngineProfiler.cpp
@@ -32,25 +32,40 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
}
}

void TRTEngineProfiler::set_profile_format(TraceFormat format) {
this->profile_format = format;
}

void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
std::stringstream out;
out << "[" << std::endl;
double ts = 0.0;
double running_time = 0.0;
for (size_t i = 0; i < value.layer_names.size(); i++) {
auto layer_name = value.layer_names[i];
auto elem = value.profile.at(layer_name);
ts += elem.time;
}
for (size_t i = 0; i < value.layer_names.size(); i++) {
auto layer_name = value.layer_names[i];
auto elem = value.profile.at(layer_name);

out << " {" << std::endl;
out << " \"name\": \"" << layer_name << "\"," << std::endl;
out << " \"ph\": \"X\"," << std::endl;
out << " \"ts\": " << ts * 1000 << "," << std::endl;
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
out << " \"tid\": 1," << std::endl;
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
out << " \"args\": {}" << std::endl;
if (value.profile_format == TraceFormat::kPERFETTO) {
out << " \"ph\": \"X\"," << std::endl;
out << " \"ts\": " << running_time * 1000 << "," << std::endl;
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
out << " \"tid\": 1," << std::endl;
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
out << " \"args\": {}" << std::endl;
} else { // kTREX
out << " \"timeMs\": " << elem.time << "," << std::endl;
out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl;
out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl;
}
out << " }," << std::endl;

ts += elem.time;
running_time += elem.time;
}
out.seekp(-2, out.cur);
out << "\n]" << std::endl;
5 changes: 4 additions & 1 deletion core/runtime/TRTEngineProfiler.h
@@ -10,12 +10,14 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

enum TraceFormat { kPERFETTO, kTREX };

struct TRTEngineProfiler : public nvinfer1::IProfiler {
struct Record {
float time{0};
int count{0};
};

void set_profile_format(TraceFormat format);
virtual void reportLayerTime(const char* layerName, float ms) noexcept;
TRTEngineProfiler(
const std::string& name,
@@ -27,6 +29,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
std::string name;
std::vector<std::string> layer_names;
std::map<std::string, Record> profile;
TraceFormat profile_format = TraceFormat::kPERFETTO;
};

} // namespace runtime
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("set_profile_format", &TRTEngine::set_profile_format)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -19,3 +19,4 @@
from ._settings import CompilationSettings
from ._SourceIR import SourceIR
from ._tracer import trace
from .debug._Debugger import Debugger
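
With `Debugger` re-exported here, the intended user-facing flow looks roughly like the sketch below. Hedged example: the keyword names are inferred from the `DebuggerConfig` fields referenced elsewhere in this diff (`save_engine_profile`, `profile_format`, `save_layer_info`, `logging_dir`) and may not match the final constructor; the model and inputs are placeholders.

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder module
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Replaces the deprecated debug=True compile flag.
with torch_tensorrt.dynamo.Debugger(
    logging_dir="/tmp/torch_tensorrt/debug_logs",  # DEBUG_LOGGING_DIR default
    save_engine_profile=True,  # per-layer engine profiles (see compile_module)
    profile_format="trex",     # "trex" needs the C++ runtime; "perfetto" otherwise
    save_layer_info=True,      # writes engine_layer_info.json to logging_dir
):
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
```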
74 changes: 55 additions & 19 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -2,6 +2,7 @@

import collections.abc
import logging
import os
import platform
import warnings
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -31,6 +32,8 @@
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
@@ -42,7 +45,6 @@
get_output_metadata,
parse_graph_io,
prepare_inputs,
set_log_level,
to_torch_device,
to_torch_tensorrt_device,
)
@@ -64,7 +66,6 @@ def cross_compile_for_windows(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -138,7 +139,6 @@ def cross_compile_for_windows(
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -185,8 +185,12 @@ def cross_compile_for_windows(
f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
)

if debug:
set_log_level(logger.parent, logging.DEBUG)
if kwargs.get("debug", False):
warnings.warn(
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality.",
DeprecationWarning,
stacklevel=2,
)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -297,7 +301,6 @@ def cross_compile_for_windows(
"enabled_precisions": (
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
),
"debug": debug,
"device": device,
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"workspace_size": workspace_size,
@@ -399,7 +402,6 @@ def compile(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -475,7 +477,6 @@ def compile(
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -518,8 +519,13 @@ def compile(
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
"""

if debug:
set_log_level(logger.parent, logging.DEBUG)
if kwargs.get("debug", False):
warnings.warn(
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality",
DeprecationWarning,
stacklevel=2,
)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
@@ -641,7 +647,6 @@ def compile(
"enabled_precisions": (
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
),
"debug": debug,
"device": device,
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"workspace_size": workspace_size,
@@ -715,12 +720,15 @@ def compile(
return trt_gm


@fn_supports_debugger
def compile_module(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
*,
_debugger_config: Optional[DebuggerConfig] = None,
) -> torch.fx.GraphModule:
"""Compile a traced FX module

@@ -744,7 +752,7 @@ def compile_module(

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
gm, settings.torch_executed_ops
)

dryrun_tracker.total_ops_in_graph = total_ops
@@ -796,7 +804,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
logger.info("Partitioning the graph via the fast partitioner")
partitioned_module, supported_ops = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
@@ -817,7 +824,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
logger.info("Partitioning the graph via the global partitioner")
partitioned_module, supported_ops = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
@@ -925,6 +931,35 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

trt_modules[name] = trt_module

if _debugger_config:

if _debugger_config.save_engine_profile:
if settings.use_python_runtime:
if _debugger_config.profile_format == "trex":
logger.warning(
"Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
)
trt_module.enable_profiling()
else:
path = os.path.join(
_debugger_config.logging_dir,
"engine_visualization_profile",
)
os.makedirs(path, exist_ok=True)
trt_module.enable_profiling(
profiling_results_dir=path,
profile_format=_debugger_config.profile_format,
)

if _debugger_config.save_layer_info:
with open(
os.path.join(
_debugger_config.logging_dir, "engine_layer_info.json"
),
"w",
) as f:
f.write(trt_module.get_layer_info())

# Parse the graph I/O and store it in dryrun tracker
parse_graph_io(gm, dryrun_tracker)

@@ -952,7 +987,6 @@ def convert_exported_program_to_serialized_trt_engine(
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
debug: bool = _defaults.DEBUG,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
workspace_size: int = _defaults.WORKSPACE_SIZE,
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1014,7 +1048,6 @@ def convert_exported_program_to_serialized_trt_engine(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
@@ -1054,8 +1087,12 @@ def convert_exported_program_to_serialized_trt_engine(
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
if debug:
set_log_level(logger.parent, logging.DEBUG)
if kwargs.get("debug", False):
warnings.warn(
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality.",
DeprecationWarning,
stacklevel=2,
)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -1139,7 +1176,6 @@ def convert_exported_program_to_serialized_trt_engine(
compilation_options = {
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"enabled_precisions": enabled_precisions,
"debug": debug,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops,
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -6,7 +6,6 @@
from torch_tensorrt._enums import EngineCapability, dtype

ENABLED_PRECISIONS = {dtype.f32}
DEBUG = False
DEVICE = None
DISABLE_TF32 = False
ASSUME_DYNAMIC_SHAPE_SUPPORT = False
@@ -50,6 +49,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
DEBUG_LOGGING_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt/debug_logs")


def default_device() -> Device:
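
The new DEBUG_LOGGING_DIR default resolves under the system temp directory; a quick way to see where debug artifacts will land (same expression as the constant above):

```python
import os
import tempfile

# Mirrors DEBUG_LOGGING_DIR in _defaults.py
print(os.path.join(tempfile.gettempdir(), "torch_tensorrt/debug_logs"))
# -> /tmp/torch_tensorrt/debug_logs on most Linux systems
```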
7 changes: 0 additions & 7 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -39,7 +39,6 @@
check_module_output,
get_model_device,
get_torch_inputs,
set_log_level,
to_torch_device,
to_torch_tensorrt_device,
)
@@ -72,7 +71,6 @@ def construct_refit_mapping(
interpreter = TRTInterpreter(
module,
inputs,
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
output_dtypes=output_dtypes,
compilation_settings=settings,
)
@@ -266,9 +264,6 @@ def refit_module_weights(
not settings.immutable_weights
), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

if settings.debug:
set_log_level(logger.parent, logging.DEBUG)

device = to_torch_tensorrt_device(settings.device)
if arg_inputs:
if not isinstance(arg_inputs, collections.abc.Sequence):
@@ -304,7 +299,6 @@ def refit_module_weights(
try:
new_partitioned_module, supported_ops = partitioning.fast_partition(
new_gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
@@ -320,7 +314,6 @@ def refit_module_weights(
if not settings.use_fast_partitioner:
new_partitioned_module, supported_ops = partitioning.global_partition(
new_gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)