Skip to content

Commit 1d34057

Browse files
authored
Simplify pass manager debug system (#3530)
1 parent 0cac315 commit 1d34057

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,21 @@
1-
from typing import Any, Callable, List, Optional
1+
import tempfile
2+
from types import new_class
3+
from typing import Any, Callable, List, Optional, Union
24

35
import torch
46
from torch.fx import passes
57
from torch.fx.passes.pass_manager import PassManager
68
from torch_tensorrt.dynamo._settings import CompilationSettings
79

810

9-
def get_draw_fx_graph_pass_lowering(
10-
idx: int, path_prefix: str, post: bool
11+
def _generate_draw_fx_graph_pass(
12+
output_path_prefix: str, name: str
1113
) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]:
12-
from torch_tensorrt.dynamo.lowering.passes import (
13-
post_lowering_pass_list,
14-
pre_lowering_pass_list,
15-
)
16-
17-
PRE_DEBUG_NAME = {
18-
i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list)
19-
}
20-
PRE_DEBUG_NAME[0] = "exported_program"
21-
22-
POST_DEBUG_NAME = {
23-
i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list)
24-
}
25-
POST_DEBUG_NAME[0] = "after_decomposition"
26-
2714
def draw_fx_graph_pass(
2815
gm: torch.fx.GraphModule, settings: CompilationSettings
2916
) -> torch.fx.GraphModule:
30-
DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx]
31-
path = f"{path_prefix}_{DEBUG_NAME}.svg"
32-
g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME)
17+
path = f"{output_path_prefix}/{name}.svg"
18+
g = passes.graph_drawer.FxGraphDrawer(gm, name)
3319
with open(path, "wb") as f:
3420
f.write(g.get_dot_graph().create_svg())
3521
return gm
@@ -47,8 +33,9 @@ def __init__(
4733
]
4834
]
4935
] = None,
36+
constraints: Optional[List[Callable]] = None
5037
):
51-
super().__init__(passes)
38+
super().__init__(passes, constraints)
5239

5340
@classmethod
5441
def build_from_passlist(
@@ -80,16 +67,48 @@ def add_pass_with_index(
8067
def remove_pass_with_index(self, index: int) -> None:
8168
del self.passes[index]
8269

83-
def insert_debug_pass(
84-
self, index: List[int], filename_prefix: str, post: bool = True
70+
def insert_debug_pass_before(
71+
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
8572
) -> None:
73+
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
74+
75+
Args:
76+
passes: List of pass names to insert debug passes before
77+
output_path_prefix: Prefix to use for generated debug files
78+
79+
Debug passes generate SVG visualizations of the FX graph at specified points
80+
in the pass sequence.
81+
"""
82+
new_pass_list = []
83+
for ps in self.passes:
84+
if ps.__name__ in passes:
85+
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}"))
86+
new_pass_list.append(ps)
87+
88+
self.passes = new_pass_list
89+
self._validated = False
90+
91+
def insert_debug_pass_after(
92+
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
93+
) -> None:
94+
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
95+
96+
Args:
97+
passes: List of pass names to insert debug passes after
98+
output_path_prefix: Prefix to use for generated debug files
99+
100+
Debug passes generate SVG visualizations of the FX graph at specified points
101+
in the pass sequence.
102+
"""
103+
new_pass_list = []
104+
for ps in self.passes:
105+
new_pass_list.append(ps)
106+
if ps.__name__ in passes:
107+
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}"))
86108

87-
for i in range(len(index)):
88109

89-
debug_pass = get_draw_fx_graph_pass_lowering(
90-
index[i], filename_prefix, post
91-
)
92-
self.add_pass_with_index(debug_pass, index[i] + i)
110+
self.passes = new_pass_list
111+
self._validated = False
93112

94113
def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
95114
self.validate()

0 commit comments

Comments
 (0)