Skip to content

Commit 0cac315

Browse files
committed
merged the file to pass_manager
1 parent 91b4376 commit 0cac315

File tree

2 files changed

+38
-65
lines changed

2 files changed

+38
-65
lines changed

py/torch_tensorrt/dynamo/lowering/passes/draw_fx_graph.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,40 @@
1-
from typing import Any, Callable, List, Optional, Sequence
1+
from typing import Any, Callable, List, Optional
22

33
import torch
4+
from torch.fx import passes
45
from torch.fx.passes.pass_manager import PassManager
56
from torch_tensorrt.dynamo._settings import CompilationSettings
6-
from torch_tensorrt.dynamo.lowering.passes.draw_fx_graph import (
7-
get_draw_fx_graph_pass_post_lowering,
8-
get_draw_fx_graph_pass_pre_lowering,
9-
)
7+
8+
9+
def get_draw_fx_graph_pass_lowering(
10+
idx: int, path_prefix: str, post: bool
11+
) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]:
12+
from torch_tensorrt.dynamo.lowering.passes import (
13+
post_lowering_pass_list,
14+
pre_lowering_pass_list,
15+
)
16+
17+
PRE_DEBUG_NAME = {
18+
i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list)
19+
}
20+
PRE_DEBUG_NAME[0] = "exported_program"
21+
22+
POST_DEBUG_NAME = {
23+
i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list)
24+
}
25+
POST_DEBUG_NAME[0] = "after_decomposition"
26+
27+
def draw_fx_graph_pass(
28+
gm: torch.fx.GraphModule, settings: CompilationSettings
29+
) -> torch.fx.GraphModule:
30+
DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx]
31+
path = f"{path_prefix}_{DEBUG_NAME}.svg"
32+
g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME)
33+
with open(path, "wb") as f:
34+
f.write(g.get_dot_graph().create_svg())
35+
return gm
36+
37+
return draw_fx_graph_pass
1038

1139

1240
class DynamoPassManager(PassManager): # type: ignore[misc]
@@ -39,8 +67,7 @@ def build_from_passlist(
3967
def add_pass_with_index(
4068
self,
4169
lowering_pass: Callable[
42-
[torch.fx.GraphModule, CompilationSettings, Sequence[torch.Tensor]],
43-
torch.fx.GraphModule,
70+
[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule
4471
],
4572
index: Optional[int] = None,
4673
) -> None:
@@ -58,14 +85,10 @@ def insert_debug_pass(
5885
) -> None:
5986

6087
for i in range(len(index)):
61-
if post:
62-
debug_pass = get_draw_fx_graph_pass_post_lowering(
63-
index[i], filename_prefix
64-
)
65-
else:
66-
debug_pass = get_draw_fx_graph_pass_pre_lowering(
67-
index[i], filename_prefix
68-
)
88+
89+
debug_pass = get_draw_fx_graph_pass_lowering(
90+
index[i], filename_prefix, post
91+
)
6992
self.add_pass_with_index(debug_pass, index[i] + i)
7093

7194
def __call__(self, gm: Any, settings: CompilationSettings) -> Any:

0 commit comments

Comments
 (0)