1
- from typing import Any , Callable , List , Optional
1
+ import tempfile
2
+ from types import new_class
3
+ from typing import Any , Callable , List , Optional , Union
2
4
3
5
import torch
4
6
from torch .fx import passes
5
7
from torch .fx .passes .pass_manager import PassManager
6
8
from torch_tensorrt .dynamo ._settings import CompilationSettings
7
9
8
10
9
- def get_draw_fx_graph_pass_lowering (
10
- idx : int , path_prefix : str , post : bool
11
+ def _generate_draw_fx_graph_pass (
12
+ output_path_prefix : str , name : str
11
13
) -> Callable [[torch .fx .GraphModule , CompilationSettings ], torch .fx .GraphModule ]:
12
- from torch_tensorrt .dynamo .lowering .passes import (
13
- post_lowering_pass_list ,
14
- pre_lowering_pass_list ,
15
- )
16
-
17
- PRE_DEBUG_NAME = {
18
- i + 1 : f"after_{ p .__name__ } " for i , p in enumerate (pre_lowering_pass_list )
19
- }
20
- PRE_DEBUG_NAME [0 ] = "exported_program"
21
-
22
- POST_DEBUG_NAME = {
23
- i + 1 : f"after_{ p .__name__ } " for i , p in enumerate (post_lowering_pass_list )
24
- }
25
- POST_DEBUG_NAME [0 ] = "after_decomposition"
26
-
27
14
def draw_fx_graph_pass (
28
15
gm : torch .fx .GraphModule , settings : CompilationSettings
29
16
) -> torch .fx .GraphModule :
30
- DEBUG_NAME = POST_DEBUG_NAME [idx ] if post else PRE_DEBUG_NAME [idx ]
31
- path = f"{ path_prefix } _{ DEBUG_NAME } .svg"
32
- g = passes .graph_drawer .FxGraphDrawer (gm , DEBUG_NAME )
17
+ path = f"{ output_path_prefix } /{ name } .svg"
18
+ g = passes .graph_drawer .FxGraphDrawer (gm , name )
33
19
with open (path , "wb" ) as f :
34
20
f .write (g .get_dot_graph ().create_svg ())
35
21
return gm
@@ -47,8 +33,9 @@ def __init__(
47
33
]
48
34
]
49
35
] = None ,
36
+ constraints : Optional [List [Callable ]] = None
50
37
):
51
- super ().__init__ (passes )
38
+ super ().__init__ (passes , constraints )
52
39
53
40
@classmethod
54
41
def build_from_passlist (
@@ -80,16 +67,48 @@ def add_pass_with_index(
80
67
def remove_pass_with_index (self , index : int ) -> None :
81
68
del self .passes [index ]
82
69
83
- def insert_debug_pass (
84
- self , index : List [int ], filename_prefix : str , post : bool = True
70
+ def insert_debug_pass_before (
71
+ self , passes : List [str ], output_path_prefix : str = tempfile . gettempdir ()
85
72
) -> None :
73
+ """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
74
+
75
+ Args:
76
+ passes: List of pass names to insert debug passes before
77
+ output_path_prefix: Prefix to use for generated debug files
78
+
79
+ Debug passes generate SVG visualizations of the FX graph at specified points
80
+ in the pass sequence.
81
+ """
82
+ new_pass_list = []
83
+ for ps in self .passes :
84
+ if ps .__name__ in passes :
85
+ new_pass_list .append (_generate_draw_fx_graph_pass (output_path_prefix , f"before_{ ps .__name__ } " ))
86
+ new_pass_list .append (ps )
87
+
88
+ self .passes = new_pass_list
89
+ self ._validated = False
90
+
91
+ def insert_debug_pass_after (
92
+ self , passes : List [str ], output_path_prefix : str = tempfile .gettempdir ()
93
+ ) -> None :
94
+ """Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
95
+
96
+ Args:
97
+ passes: List of pass names to insert debug passes after
98
+ output_path_prefix: Prefix to use for generated debug files
99
+
100
+ Debug passes generate SVG visualizations of the FX graph at specified points
101
+ in the pass sequence.
102
+ """
103
+ new_pass_list = []
104
+ for ps in self .passes :
105
+ new_pass_list .append (ps )
106
+ if ps .__name__ in passes :
107
+ new_pass_list .append (_generate_draw_fx_graph_pass (output_path_prefix , f"after_{ ps .__name__ } " ))
86
108
87
- for i in range (len (index )):
88
109
89
- debug_pass = get_draw_fx_graph_pass_lowering (
90
- index [i ], filename_prefix , post
91
- )
92
- self .add_pass_with_index (debug_pass , index [i ] + i )
110
+ self .passes = new_pass_list
111
+ self ._validated = False
93
112
94
113
def __call__ (self , gm : Any , settings : CompilationSettings ) -> Any :
95
114
self .validate ()
0 commit comments