Commit f1654f5

fix bugs and clean codes
1 parent bcb50b2 commit f1654f5

File tree

4 files changed: +89 -62 lines changed

examples/hierarchical_partitioner_example.py

Lines changed: 21 additions & 19 deletions
@@ -1,9 +1,6 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
-    DYNAMO_ATEN_CONVERTERS,
-)
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -15,6 +12,7 @@
 from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
     hierarchical_adjacency_partition,
 )
+from torchvision import models


 class SimpleModel(nn.Module):
@@ -50,18 +48,18 @@ def main():

     gm = exported_program.module()

-    print(gm.graph)
+    print(gm)

     original_output = model(example_input)

-    # Partition the model using the adjacency partitioner
+    # Partition the model using the adjacency partitioner, compared with below
     # partitioned_model, op_support = partition(
     #     gm,
     #     verbose=True,
     #     min_block_size=1,
-    #     torch_executed_ops=[
-    #         torch.ops.aten.relu.default,
-    #     ],
+    #     torch_executed_ops={
+    #         "torch.ops.aten.relu.default",
+    #     },
     # )

     partitioned_model, op_support = hierarchical_adjacency_partition(
@@ -71,21 +69,18 @@ def main():
         backend_priority=["inductor", "tensorrt"],
         backend_support_map={
             "inductor": {
-                # operator.getitem,
-                torch.ops.aten.conv2d.default,
-                torch.ops.aten.convolution.default,
+                "torch.ops.aten.convolution.default",
             },
-            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
+            "tensorrt": CONVERTERS.keys(),
+        },
+        torch_executed_ops={
+            "torch.ops.aten._native_batch_norm_legit_no_training.default"
         },
-        torch_executed_ops=[
-            torch.ops.aten._native_batch_norm_legit_no_training.default
-        ],
         require_full_compilation=False,
-        skip_fusion=False,
+        skip_fusion=True,
     )

-    print("\nPartitioned Model Structure:")
-    print(partitioned_model)
+    print("\nPartitioned Model Structure:\n", partitioned_model)

     print("0. Original_output:", original_output)

@@ -98,8 +93,15 @@ def main():
     )

     compiled_model = torch_tensorrt.compile(
-        model, inputs=[example_input], min_block_size=1
+        model,
+        inputs=[example_input],
+        min_block_size=1,
+        torch_executed_ops={
+            "torch.ops.aten._native_batch_norm_legit_no_training.default"
+        },
     )
+    print("\nCompiled Model Structure:\n", compiled_model)
+
     with torch.no_grad():
         compiled_output = compiled_model(example_input)
         print("2. Compiled_output:", compiled_output)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 25 additions & 19 deletions
@@ -4,7 +4,17 @@
 import logging
 import platform
 import warnings
-from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 import torch
 from torch.export import ExportedProgram
@@ -28,9 +38,6 @@
     interpret_module_to_result,
     repair_double_inputs,
 )
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
-    DYNAMO_ATEN_CONVERTERS,
-)
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -792,16 +799,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     )

     ############ TODO: testing only ############
-    use_hierarchical_partitioner = False
+    use_hierarchical_partitioner = True
     backend_priority = ["inductor", "tensorrt"]
     backend_support_map = {
         "inductor": {
-            # operator.getitem,
-            torch.ops.aten.conv2d.default,
-            torch.ops.aten.convolution.default,
+            "torch.ops.aten.convolution.default",
         },
-        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
+        "tensorrt": CONVERTERS.keys(),
     }
+    skip_fusion = True
     #############################################
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
@@ -819,7 +825,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
-                skip_fusion=(num_supported_ops == total_ops),
+                skip_fusion=skip_fusion,
                 backend_priority=backend_priority,
                 backend_support_map=backend_support_map,
             )
@@ -953,19 +959,17 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             if "_run_on_acc_inductor" in name:
                 sub_inputs = []
                 for input in submodule_inputs:
-                    sub_input = (
-                        torch.randn(input.shape)
-                        .to(dtype.to(input.dtype, t=torch.dtype))
-                        .cuda()
-                    )
+                    sub_input = input.torch_tensor.to(
+                        dtype.to(input.dtype, t=torch.dtype)
+                    ).cuda()
                     sub_inputs.append(sub_input)

                 compiled_func = torch._inductor.compile(
                     submodule,
                     sub_inputs,
                 )
                 # Wrap the compiled function to be a torch.nn.Module
-                compiled_submodule = FunctionWrapper(compiled_func)
+                compiled_submodule = InductorModule(compiled_func)

             elif "_run_on_acc_tensorrt" in name:
                 compiled_submodule = convert_module(
@@ -1345,10 +1349,12 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
     return replace_execute_engine_no_op_node(exp_program)


-class FunctionWrapper(torch.nn.Module):
-    def __init__(self, func):
+class InductorModule(torch.nn.Module):  # type: ignore[misc]
+    """Wrapper module for inductor compiled function."""
+
+    def __init__(self, func: Callable[..., Any]) -> None:
         super().__init__()
         self.func = func

-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         return self.func(*args, **kwargs)
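
InductorModule (the renamed FunctionWrapper) is only a thin nn.Module shim around whatever callable torch._inductor.compile returns, so the compiled submodule can be registered on the parent fx.GraphModule like any other child module. A self-contained sketch of the wrapper's behavior, with a plain function standing in for the inductor-compiled callable:

    from typing import Any, Callable

    import torch


    class InductorModule(torch.nn.Module):  # type: ignore[misc]
        """Wrapper module for an inductor-compiled function."""

        def __init__(self, func: Callable[..., Any]) -> None:
            super().__init__()
            self.func = func

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            return self.func(*args, **kwargs)


    def compiled_func(x: torch.Tensor) -> torch.Tensor:
        # Stands in for the callable returned by torch._inductor.compile.
        return x * 2


    compiled_submodule = InductorModule(compiled_func)
    print(compiled_submodule(torch.ones(2)))  # tensor([2., 2.])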

py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py

Lines changed: 25 additions & 12 deletions
@@ -1,12 +1,11 @@
 import logging
 from dataclasses import dataclass
-from typing import Collection, Dict, List, Optional, Set, Tuple
+from typing import Collection, Dict, List, Optional, Tuple

 import torch
 import torch.fx.passes.operator_support as ops
-from torch._ops import OpOverload
 from torch.fx._compatibility import compatibility
-from torch.fx.node import Target, _get_qualified_name
+from torch.fx.node import Target
 from torch.fx.passes.splitter_base import (
     _SplitterBase,
     _SplitterSettingBase,
@@ -24,12 +23,15 @@
     REQUIRE_FULL_COMPILATION,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
-    DYNAMO_ATEN_CONVERTERS,
+    DYNAMO_CONVERTERS,
     ConverterRegistry,
 )

 logger = logging.getLogger(__name__)

+NON_COMPUTE_NODES = {"torch.ops.aten.view", "_operator.getitem"}
+NON_ACC_BACKEND_NAME = "None"
+

 @compatibility(is_backward_compatible=False)
 @dataclass
@@ -45,7 +47,7 @@ class BackendOpSupportTester(ops.OperatorSupportBase):  # type: ignore

     def __init__(
         self,
-        backend_support_map: Dict[str, Set[OpOverload]],
+        backend_support_map: Dict[str, Collection[Target]],
         backend_priority: List[str],
         torch_executed_ops: Collection[Target] = set(),
     ) -> None:
@@ -62,12 +64,14 @@ def __init__(

     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
-    ) -> Tuple[bool, Optional[str]]:
+    ) -> Tuple[bool, str]:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

         for i, backend_name in enumerate(self.backend_priority):
             supported_ops = self.backend_support_map.get(backend_name, set())
-            supported_ops = {_get_qualified_name(op) for op in supported_ops}
+            supported_ops = {
+                ConverterRegistry.qualified_name_or_str(op) for op in supported_ops
+            }

             if (
                 (node_name in supported_ops or node.op == "get_attr")
@@ -89,7 +93,7 @@ def is_node_supported(
         else:
             self.unsupported_operators[node_name] += 1

-        return False, None
+        return False, NON_ACC_BACKEND_NAME

     def print_support_overview(self, num_acc_subgraphs: Optional[int] = None) -> None:
         if num_acc_subgraphs is not None:
@@ -137,7 +141,7 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         operator_support: ops.OperatorSupportBase,
-        backend_support_map: Dict[str, Set[Target]],
+        backend_support_map: Dict[str, Collection[Target]],
         backend_priority: List[str],
         allowed_single_node_partition_ops: Optional[Collection[str]] = None,
         min_block_size: int = MIN_BLOCK_SIZE,
@@ -488,15 +492,24 @@ def reduce_acc_nodes_non_tensor_output(self):

     def __call__(self) -> NodeSet:
         submodules = dict(self.module.named_modules())
+        backend = NON_ACC_BACKEND_NAME
         for n in self.module.graph.nodes:
-            n.backend = "None"
+            # Group non-compute nodes with previous compute nodes
+            if ConverterRegistry.qualified_name_or_str(n.target) in NON_COMPUTE_NODES:
+                n.backend = backend
+                if backend != NON_ACC_BACKEND_NAME:
+                    self.acc_nodes.add(n)
+                continue
+
             if n.op in CALLABLE_NODE_OPS:
                 is_supported, backend = self.operator_support.is_node_supported(
                     submodules, n
                 )
                 if is_supported:
                     n.backend = backend
                     self.acc_nodes.add(n)
+                else:
+                    n.backend = NON_ACC_BACKEND_NAME

         if not self.allow_non_tensor:
             self.reduce_acc_nodes_non_tensor_input()
@@ -515,7 +528,7 @@ def hierarchical_adjacency_partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
-    backend_support_map: Optional[Dict[str, Set[OpOverload]]] = None,
+    backend_support_map: Optional[Dict[str, Collection[Target]]] = None,
     backend_priority: Optional[List[str]] = None,
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     skip_fusion: bool = False,
@@ -542,7 +555,7 @@ def hierarchical_adjacency_partition(
     # Default backend support map if none provided
     if backend_support_map is None:
         backend_support_map = {
-            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
+            "tensorrt": DYNAMO_CONVERTERS.keys(),
             "inductor": set(),
         }
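
The new NON_COMPUTE_NODES handling in __call__ makes shape-only nodes (aten.view, operator.getitem) follow whatever backend was assigned to the most recent compute node, instead of always landing in the unassigned "None" bucket and breaking up partitions. A standalone toy re-statement of that grouping rule (not the splitter class itself) is sketched below:

    NON_COMPUTE_NODES = {"torch.ops.aten.view", "_operator.getitem"}
    NON_ACC_BACKEND_NAME = "None"


    def assign_backends(node_names, is_node_supported):
        """Toy version of the backend-assignment loop: non-compute nodes
        inherit the backend of the previous compute node."""
        assignments = {}
        backend = NON_ACC_BACKEND_NAME
        for name in node_names:
            if name in NON_COMPUTE_NODES:
                # Group with the preceding compute node's backend.
                assignments[name] = backend
                continue
            _supported, backend = is_node_supported(name)
            assignments[name] = backend
        return assignments


    support = {
        "torch.ops.aten.convolution.default": (True, "inductor"),
        "torch.ops.aten.relu.default": (True, "tensorrt"),
    }
    graph_order = [
        "torch.ops.aten.convolution.default",
        "torch.ops.aten.view",  # follows convolution -> "inductor"
        "torch.ops.aten.relu.default",
        "_operator.getitem",  # follows relu -> "tensorrt"
    ]
    print(
        assign_backends(
            graph_order, lambda n: support.get(n, (False, NON_ACC_BACKEND_NAME))
        )
    )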

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 18 additions & 12 deletions
@@ -84,43 +84,49 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
         module_inputs = [
             node for node in module.graph.nodes if node.op == "placeholder"
         ]
-        for input in module_inputs:
-            if input.meta:
-                if "val" in input.meta:
-                    input_meta = input.meta["val"]
+        for input_node in module_inputs:
+            if input_node.meta:
+                if "val" in input_node.meta:
+                    input_meta = input_node.meta["val"]
+
+                    if isinstance(input_meta, Sequence):
+                        input_meta = input_meta[0]
+
                     if isinstance(input_meta, (FakeTensor, torch.Tensor)):
                         input_shape = input_meta.size()
                         torchtrt_inputs.append(
-                            get_input(input_shape, input_meta.dtype, name=input.name)
+                            get_input(
+                                input_shape, input_meta.dtype, name=input_node.name
+                            )
                         )
                     elif isinstance(input_meta, torch.SymInt):
                         # Assuming sym_integers | shape inputs always have torch.int64 dtype
                         torchtrt_inputs.append(
                             get_input(
                                 [input_meta],
                                 torch.int64,
-                                name=input.name,
+                                name=input_node.name,
                                 is_shape_tensor=True,
                             )
                         )
                     else:
                         raise ValueError(
-                            f"The meta val for input node {input.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt"
+                            f"The meta val for input node {input_node.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt"
                         )

-                elif "tensor_meta" in input.meta:
-                    input_meta = input.meta["tensor_meta"]
+                elif "tensor_meta" in input_node.meta:
+                    input_meta = input_node.meta["tensor_meta"]
                     input_shape = input_meta.shape
                     torchtrt_inputs.append(
-                        get_input(input_shape, input_meta.dtype, name=input.name)
+                        get_input(input_shape, input_meta.dtype, name=input_node.name)
                     )
                 else:
                     raise AssertionError(
-                        f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
+                        f"Input {input_node.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
                     )
             else:
                 raise AssertionError(
-                    f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
+                    f"Input {input_node.name} does not contain metadata. Please ensure you have exported the graph correctly"
                 )

         return torchtrt_inputs
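
Beyond renaming the loop variable so it no longer shadows the Python builtin input, the functional change here is the new Sequence check: when a placeholder's meta["val"] is a list or tuple of fake tensors, the first element is used to derive the shape and dtype. A toy sketch of that unwrapping, separate from the torch_tensorrt helper:

    from collections.abc import Sequence

    import torch


    def shape_and_dtype_from_val(val):
        """Toy extraction of shape/dtype from a placeholder's meta["val"],
        unwrapping a sequence val by taking its first element, as the
        updated construct_submodule_inputs does."""
        if isinstance(val, Sequence):
            val = val[0]
        if isinstance(val, torch.Tensor):
            return tuple(val.size()), val.dtype
        raise ValueError(f"Unsupported meta val type: {type(val)}")


    print(shape_and_dtype_from_val([torch.empty(2, 3), torch.empty(4)]))
    # ((2, 3), torch.float32)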
