1
- from typing import Any , Callable , List , Optional , Sequence
1
+ from typing import Any , Callable , List , Optional
2
2
3
3
import torch
4
+ from torch .fx import passes
4
5
from torch .fx .passes .pass_manager import PassManager
5
6
from torch_tensorrt .dynamo ._settings import CompilationSettings
6
- from torch_tensorrt .dynamo .lowering .passes .draw_fx_graph import (
7
- get_draw_fx_graph_pass_post_lowering ,
8
- get_draw_fx_graph_pass_pre_lowering ,
9
- )
7
+
8
+
9
+ def get_draw_fx_graph_pass_lowering (
10
+ idx : int , path_prefix : str , post : bool
11
+ ) -> Callable [[torch .fx .GraphModule , CompilationSettings ], torch .fx .GraphModule ]:
12
+ from torch_tensorrt .dynamo .lowering .passes import (
13
+ post_lowering_pass_list ,
14
+ pre_lowering_pass_list ,
15
+ )
16
+
17
+ PRE_DEBUG_NAME = {
18
+ i + 1 : f"after_{ p .__name__ } " for i , p in enumerate (pre_lowering_pass_list )
19
+ }
20
+ PRE_DEBUG_NAME [0 ] = "exported_program"
21
+
22
+ POST_DEBUG_NAME = {
23
+ i + 1 : f"after_{ p .__name__ } " for i , p in enumerate (post_lowering_pass_list )
24
+ }
25
+ POST_DEBUG_NAME [0 ] = "after_decomposition"
26
+
27
+ def draw_fx_graph_pass (
28
+ gm : torch .fx .GraphModule , settings : CompilationSettings
29
+ ) -> torch .fx .GraphModule :
30
+ DEBUG_NAME = POST_DEBUG_NAME [idx ] if post else PRE_DEBUG_NAME [idx ]
31
+ path = f"{ path_prefix } _{ DEBUG_NAME } .svg"
32
+ g = passes .graph_drawer .FxGraphDrawer (gm , DEBUG_NAME )
33
+ with open (path , "wb" ) as f :
34
+ f .write (g .get_dot_graph ().create_svg ())
35
+ return gm
36
+
37
+ return draw_fx_graph_pass
10
38
11
39
12
40
class DynamoPassManager (PassManager ): # type: ignore[misc]
@@ -39,8 +67,7 @@ def build_from_passlist(
39
67
def add_pass_with_index (
40
68
self ,
41
69
lowering_pass : Callable [
42
- [torch .fx .GraphModule , CompilationSettings , Sequence [torch .Tensor ]],
43
- torch .fx .GraphModule ,
70
+ [torch .fx .GraphModule , CompilationSettings ], torch .fx .GraphModule
44
71
],
45
72
index : Optional [int ] = None ,
46
73
) -> None :
@@ -58,14 +85,10 @@ def insert_debug_pass(
58
85
) -> None :
59
86
60
87
for i in range (len (index )):
61
- if post :
62
- debug_pass = get_draw_fx_graph_pass_post_lowering (
63
- index [i ], filename_prefix
64
- )
65
- else :
66
- debug_pass = get_draw_fx_graph_pass_pre_lowering (
67
- index [i ], filename_prefix
68
- )
88
+
89
+ debug_pass = get_draw_fx_graph_pass_lowering (
90
+ index [i ], filename_prefix , post
91
+ )
69
92
self .add_pass_with_index (debug_pass , index [i ] + i )
70
93
71
94
def __call__ (self , gm : Any , settings : CompilationSettings ) -> Any :
0 commit comments