@@ -25,10 +25,6 @@ def __init__(
2525 logging_dir : Optional [str ] = None ,
2626 ):
2727 self .debug_file_dir = tempfile .TemporaryDirectory ().name
28- if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile ):
29- _LOGGER .warning (
30- "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
31- )
3228
3329 if log_level == "debug" :
3430 self .log_level = logging .DEBUG
@@ -60,7 +56,7 @@ def __enter__(self) -> None:
6056 self .rt_level = torch .ops .tensorrt .get_logging_level ()
6157 dictConfig (self .get_config ())
6258
63- if self .log_level == GRAPH_LEVEL :
59+ if self .capture_fx_graph_before or self . capture_fx_graph_after :
6460 self .old_pre_passes , self .old_post_passes = (
6561 ATEN_PRE_LOWERING_PASSES .passes ,
6662 ATEN_POST_LOWERING_PASSES .passes ,
@@ -93,14 +89,14 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
9389
9490 dictConfig (self .get_default_config ())
9591 torch .ops .tensorrt .set_logging_level (self .rt_level )
96- if self .log_level == GRAPH_LEVEL and self .capture_fx_graph_after :
92+ if self .capture_fx_graph_before or self .capture_fx_graph_after :
9793 ATEN_PRE_LOWERING_PASSES .passes , ATEN_POST_LOWERING_PASSES .passes = (
9894 self .old_pre_passes ,
9995 self .old_post_passes ,
10096 )
10197 self .debug_file_dir = tempfile .TemporaryDirectory ().name
10298
103- def get_config (self ) -> dict [str , Any ]:
99+ def get_customized_logging_config (self ) -> dict [str , Any ]:
104100 config = {
105101 "version" : 1 ,
106102 "disable_existing_loggers" : False ,
@@ -138,7 +134,7 @@ def get_config(self) -> dict[str, Any]:
138134 }
139135 return config
140136
141- def get_default_config (self ) -> dict [str , Any ]:
137+ def get_default_logging_config (self ) -> dict [str , Any ]:
142138 config = {
143139 "version" : 1 ,
144140 "disable_existing_loggers" : False ,
0 commit comments