feat: Hierarchical Partitioner to support multi-backends #3539
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/hierarchical_partitioner_example.py 2025-05-29 22:05:48.305564+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/hierarchical_partitioner_example.py 2025-05-29 22:06:09.728061+00:00
@@ -6,28 +6,31 @@
post_lowering,
pre_export_lowering,
)
import torchvision.models as models
from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import hierarchical_partition
+from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
+ hierarchical_partition,
+)
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition
+
# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition
import operator
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
- DYNAMO_ATEN_CONVERTERS
+ DYNAMO_ATEN_CONVERTERS,
)
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(128)
-
+
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = torch.relu(x)
x = self.conv2(x)
@@ -39,25 +42,22 @@
def main():
# Create model
model = SimpleModel().cuda()
# model = models.efficientnet_b0(pretrained=True).cuda()
model = model.eval()
-
+
# Create example input
example_input = torch.randn(1, 3, 224, 224).cuda()
-
+
exported_program = torch.export.export(model, (example_input,))
exported_program = pre_export_lowering(exported_program)
- exported_program = exported_program.run_decompositions(
- get_decompositions()
- )
+ exported_program = exported_program.run_decompositions(get_decompositions())
gm = exported_program.module()
-
+
print(gm.graph)
-
-
+
# Partition the model using the adjacency partitioner
# partitioned_model, op_support = partition(
# gm,
# verbose=True,
# min_block_size=1,
@@ -68,16 +68,16 @@
partitioned_model, op_support = hierarchical_partition(
gm,
verbose=True,
min_block_size=1,
- backend_priority=["mlir", "tensorrt"], #, "inductor"],
+ backend_priority=["mlir", "tensorrt"], # , "inductor"],
backend_support_map={
"mlir": {
# operator.getitem,
- torch.ops.aten.conv2d.default,
- torch.ops.aten.convolution.default,
+ torch.ops.aten.conv2d.default,
+ torch.ops.aten.convolution.default,
},
"tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
# "inductor": {
# torch.ops.aten.relu.default,
# },
@@ -86,18 +86,20 @@
torch.ops.aten._native_batch_norm_legit_no_training.default
],
require_full_compilation=False,
skip_fusion=False,
)
-
+
print("\nPartitioned Model Structure:")
print(partitioned_model)
with torch.no_grad():
output = partitioned_model(example_input)
print("Partitioned output:", output)
- print("Partitioned output == original output:", torch.allclose(model(example_input), output, 1e-2, 1e-2))
-
+ print(
+ "Partitioned output == original output:",
+ torch.allclose(model(example_input), output, 1e-2, 1e-2),
+ )
if __name__ == "__main__":
main()
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py 2025-05-29 22:05:48.319565+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py 2025-05-29 22:06:11.554604+00:00
@@ -3,10 +3,11 @@
from collections import defaultdict
import torch
import torch.fx.passes.operator_support as ops
from torch.fx.node import Target, Node
from torch.fx.graph_module import GraphModule
+
# from torch.fx.passes.splitter_base import (
# FxNetAccFusionsFinder,
# FxNetAccNodesFinder,
# Subgraph,
# _SplitterBase,
@@ -25,11 +26,11 @@
MIN_BLOCK_SIZE,
REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
- DYNAMO_ATEN_CONVERTERS
+ DYNAMO_ATEN_CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
from torch._ops import OpOverload
from torch.fx.node import _get_qualified_name
import torch_tensorrt.dynamo.partitioning._adjacency_partitioner as _adjacency_partitioner
@@ -38,31 +39,36 @@
class BackendOpSupportTester(ops.OperatorSupportBase): # type: ignore
"""Class to determine whether operators are supported by specific backends"""
- def __init__(self, backend_support_map: Dict[str, Set[OpOverload]], backend_priority: List[str], torch_executed_ops: Collection[Target] = set()) -> None:
+ def __init__(
+ self,
+ backend_support_map: Dict[str, Set[OpOverload]],
+ backend_priority: List[str],
+ torch_executed_ops: Collection[Target] = set(),
+ ) -> None:
super().__init__()
# Initialize sets of supported/unsupported operators
self.supported_operators: Dict[str, int] = {}
self.unsupported_operators: Dict[str, int] = {}
self.torch_executed_ops = torch_executed_ops
# Map of backend names to sets of supported operators
self.backend_support_map = backend_support_map
# Ordered list of backend names, from highest to lowest priority
self.backend_priority = backend_priority
-
+
def is_node_supported(
self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
) -> Tuple[bool, Optional[str]]:
node_name = ConverterRegistry.qualified_name_or_str(node.target)
-
+
for i, backend_name in enumerate(self.backend_priority):
supported_ops = self.backend_support_map.get(backend_name, set())
supported_ops = set(_get_qualified_name(op) for op in supported_ops)
-
+
if (
(node_name in supported_ops or node.op == "get_attr")
and node_name not in self.torch_executed_ops
and node.target not in self.torch_executed_ops
):
@@ -106,11 +112,11 @@
logger.debug("\nAll Nodes Supported\n")
class HierarchicalTRTPartitioner(_SplitterBase): # type: ignore
"""Hierarchical partitioner to split an FX graph into subgraphs based on backend priority
-
+
This partitioner extends the TRTPartitioner of adjacency_partitioner with backend priority awareness,
allowing different parts of the model to be executed on different backends based on
operator support and priority ordering.
Args:
@@ -155,11 +161,11 @@
# Get all accelerated nodes based on operator support conditions
self.acc_nodes = FxNetAccNodesFinder(
self.module, self.operator_support, self.settings.allow_non_tensor
)()
-
+
if self.settings.skip_fusion:
self.fusions = {}
else:
self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))()
@@ -200,11 +206,13 @@
logger.debug(
"Eliminating acc subgraph because it's smaller than the threshold: "
f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
)
# if the last subgraph result[-1] is non-acc or has the same backend, merge the current subgraph into it
- if result and (not result[-1].is_acc or result[-1].backend == subgraph.backend):
+ if result and (
+ not result[-1].is_acc or result[-1].backend == subgraph.backend
+ ):
result[-1].nodes.extend(subgraph.nodes)
else:
# if the last subgraph result[-1] has different backends, then append the current subgraph as non-acc
subgraph.is_acc = False
subgraph.backend = "None"
@@ -265,19 +273,19 @@
"""Generates starter nodes for partitioning + segmentation"""
# Starter accelerated nodes are all callable accelerated ops
starter_acc_nodes = {
node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
}
-
+
# Started non-accelerated nodes are the rest of the callable nodes
starter_non_acc_nodes = {
node
for node in self.module.graph.nodes
if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
}
return starter_non_acc_nodes, starter_acc_nodes
-
+
def put_nodes_into_subgraphs(self) -> list[Subgraph]:
# We start graph traversal from leaf nodes
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
@@ -301,23 +309,38 @@
if node is None:
if not current_subgraph_nodes:
raise FxNetSplitterInternalError("Subgraph can't be empty")
print(222222222222, current_subgraph_nodes, acc_subgraph)
subgraphs.append(
- Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend if acc_subgraph else "None")
+ Subgraph(
+ is_acc=acc_subgraph,
+ nodes=current_subgraph_nodes,
+ backend=(
+ current_subgraph_nodes[-1].backend
+ if acc_subgraph
+ else "None"
+ ),
+ )
)
acc_subgraph = not acc_subgraph
current_subgraph_nodes = []
continue
# If the backend changed, then it's time to start a new subgraph
- if current_subgraph_nodes and current_subgraph_nodes[-1].backend != node.backend:
+ if (
+ current_subgraph_nodes
+ and current_subgraph_nodes[-1].backend != node.backend
+ ):
if not current_subgraph_nodes:
raise FxNetSplitterInternalError("Subgraph can't be empty")
print(333333333333, current_subgraph_nodes, acc_subgraph)
subgraphs.append(
- Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend)
+ Subgraph(
+ is_acc=acc_subgraph,
+ nodes=current_subgraph_nodes,
+ backend=current_subgraph_nodes[-1].backend,
+ )
)
current_subgraph_nodes = []
continue
current_nodes.remove(node)
@@ -343,11 +366,17 @@
current_cpu_nodes.add(user)
# Check if the last subgraph was not created
if current_subgraph_nodes:
subgraphs.append(
- Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend if acc_subgraph else "None")
+ Subgraph(
+ is_acc=acc_subgraph,
+ nodes=current_subgraph_nodes,
+ backend=(
+ current_subgraph_nodes[-1].backend if acc_subgraph else "None"
+ ),
+ )
)
if not subgraphs:
raise FxNetSplitterInternalError("Couldn't create subgraphs")
print(666666666666, subgraphs)
@@ -384,22 +413,22 @@
"""
# Ensure graph is clean prior to partitioning
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
-
+
# Default backend support map if none provided
if backend_support_map is None:
backend_support_map = {
"tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
"inductor": set(),
}
# Default backend priority if none provided
if backend_priority is None:
backend_priority = ["tensorrt", "inductor"]
-
+
# Construct BackendOpSupportTester
supported_ops = BackendOpSupportTester(
backend_support_map=backend_support_map,
backend_priority=backend_priority,
torch_executed_ops=torch_executed_ops,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/splitter_base.py 2025-05-29 22:05:48.320565+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/splitter_base.py 2025-05-29 22:06:11.987143+00:00
@@ -187,11 +187,13 @@
def __call__(self) -> NodeSet:
submodules = dict(self.module.named_modules())
for n in self.module.graph.nodes:
n.backend = "None"
if n.op in CALLABLE_NODE_OPS:
- is_supported, backend = self.operator_support.is_node_supported(submodules, n)
+ is_supported, backend = self.operator_support.is_node_supported(
+ submodules, n
+ )
if is_supported:
n.backend = backend
self.acc_nodes.add(n)
if not self.allow_non_tensor:
Considering how this feature will be made available, min_block_size and torch_executed_ops need to be rethought or deprecated, since min_block_size can apply to all backends and torch_executed_ops will be replaced by backend_support_map.
        torch.ops.aten._native_batch_norm_legit_no_training.default
    ],
    require_full_compilation=False,
    skip_fusion=False,
skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?
> the min_block_size and torch_executed_ops need to be re-thought or deprecated as the min_block_size can apply to all backends

My understanding is that if the number of ops in a GM is less than min_block_size, the GM would not be compiled, no matter whether it targets TRT or another backend. Do you mean it shouldn't apply to other backends?

> torch_executed_ops will be replaced by backend_support_map

My understanding is that torch-executed is not considered a backend because it doesn't need any compilation; it just runs ops in eager mode. So if an op is in torch_executed_ops, it ignores backend_support_map and runs in torch eager anyway.
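A minimal illustration of that precedence, using hypothetical names rather than the PR's actual code: an op listed in torch_executed_ops is rejected for every backend before backend_support_map is even consulted, so it stays in eager PyTorch.

```python
# Hypothetical illustration of the precedence described above; not the PR's actual code.
backend_support_map = {"tensorrt": {"torch.ops.aten.conv2d.default"}}
backend_priority = ["tensorrt"]
torch_executed_ops = {"torch.ops.aten.conv2d.default"}

def resolve_backend(op_name: str) -> str:
    if op_name in torch_executed_ops:
        return "None"  # forced to torch eager, regardless of backend support
    for backend in backend_priority:
        if op_name in backend_support_map.get(backend, set()):
            return backend  # first (highest-priority) backend that supports the op wins
    return "None"

# conv2d is supported by TensorRT, but torch_executed_ops takes precedence:
print(resolve_backend("torch.ops.aten.conv2d.default"))  # -> "None"
```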
> skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Since the adjacency partitioner uses this flag, I just kept it here. Yeah, I can definitely switch it to True in the example.
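For context, a paraphrased sketch of what the flag gates inside the partitioner (see the FxNetAccFusionsFinder hunk in the diff above); skipping that fusion search is what makes skip_fusion=True cheaper:

```python
# Paraphrased from the diff above; illustrative only, not a drop-in replacement.
def build_fusions(skip_fusion, module, acc_nodes, fusions_finder):
    if skip_fusion:
        # No fusion groups are computed, so the expensive fusion search is skipped entirely.
        return {}
    return fusions_finder(module, set(acc_nodes))()
```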
sub_input = (
    torch.randn(input.shape)
    .to(dtype.to(input.dtype, t=torch.dtype))
    .cuda()
)
sub_inputs.append(sub_input)

compiled_func = torch._inductor.compile(
    submodule,
    sub_inputs,
)
The torch._dynamo.mark_dynamic API is used to set dynamic shapes for the torch.compile workflow. Reference: https://docs.pytorch.org/TensorRT/user_guide/dynamic_shapes.html. So you can use the construct_submodule_inputs() API to give you dynamic inputs (if that's the case) and set them on the inductor segment accordingly.
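A minimal sketch of that suggestion, assuming a dynamic batch dimension on the inductor segment's input (the relu segment here is just a stand-in):

```python
import torch

# Hypothetical sketch: mark the batch dimension as dynamic before compiling the segment.
example_input = torch.randn(8, 3, 224, 224)
torch._dynamo.mark_dynamic(example_input, 0)  # dim 0 (batch) may vary at runtime

compiled_segment = torch.compile(lambda x: torch.relu(x))
out = compiled_segment(example_input)  # traced with a symbolic batch dimension
```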
submodule_inputs is already the return value of construct_submodule_inputs(). What I did here is convert the Torch-TRT Input objects to torch Tensors. I changed it to:
sub_input = (
    input.torch_tensor
    .to(dtype.to(input.dtype, t=torch.dtype))
    .cuda()
)
else:
    raise ValueError(f"Unknown backend for submodule: {name}")
Should there be a _run_on_gpu segment here? How would torch.ops.aten._native_batch_norm_legit_no_training.default fall back to native PyTorch in your example from Slack?
Nope, as I mentioned above, _run_on_gpu is not considered a backend. We just keep the module as is. Only _run_on_acc_backend modules need to be replaced with the compiled module. This aligns with our existing implementation.
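A minimal sketch of that rule, with a hypothetical helper name: only the accelerated submodules are swapped for compiled modules, while _run_on_gpu submodules are left untouched and keep running in eager PyTorch.

```python
import torch

# Hypothetical helper illustrating the replacement rule described above; not the PR code.
def swap_in_compiled_submodules(partitioned_module: torch.nn.Module, compile_fn) -> None:
    for name, submodule in list(partitioned_module.named_children()):
        if "_run_on_acc" in name:
            # e.g. "_run_on_acc_tensorrt_0" or "_run_on_acc_mlir_1" -> replace with compiled module
            setattr(partitioned_module, name, compile_fn(name, submodule))
        # "_run_on_gpu_*" submodules fall through and stay as-is
```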
class FunctionWrapper(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)
Consider renaming this to InductorModule and moving it to utils.
    require_full_compilation=require_full_compilation,
    skip_fusion=skip_fusion,
)
Do you think copying all this code is necessary, or would inheriting only some components be sufficient?
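One possible shape of the inheritance alternative, as a rough sketch only: subclass the adjacency partitioner's TRTPartitioner (assuming it is importable from _adjacency_partitioner, as the docstring in the diff implies) and override only the backend-aware grouping step. The _split_by_backend helper below is hypothetical.

```python
# Rough sketch, not the PR's implementation.
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import TRTPartitioner

class HierarchicalPartitionerSketch(TRTPartitioner):  # type: ignore
    def put_nodes_into_subgraphs(self):
        # Reuse the base splitter's acc/non-acc grouping ...
        subgraphs = super().put_nodes_into_subgraphs()
        # ... then cut each accelerated subgraph wherever the per-node backend annotation changes.
        return self._split_by_backend(subgraphs)

    def _split_by_backend(self, subgraphs):
        ...  # hypothetical: walk subgraph.nodes and start a new Subgraph on node.backend boundaries
```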
# Wrap the compiled function to be a torch.nn.Module
compiled_submodule = FunctionWrapper(compiled_func)

elif "_run_on_acc_tensorrt" in name:
Is there some sort of design where the capability and conversion parts can be grouped and registered? We can add this concept to the RFC for later.
That way we don't have a long conditional case set; we just look up the appropriate conversion function on a standardized API.
Yeah, I agree that the conversion parts of different backends should be grouped, but for now I don't have much info about other backends (like how to convert an op to a given backend). We can definitely do this when we are ready to support other backends.
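A possible shape of that registration concept, purely as an illustrative sketch (the names here are hypothetical, not an existing Torch-TensorRT API):

```python
from typing import Callable, Dict, Sequence, Set, Tuple

import torch

# Hypothetical registry: each backend contributes its capability set and its compile function.
BACKEND_REGISTRY: Dict[str, Tuple[Set, Callable]] = {}

def register_backend(name: str, supported_ops: Set, compile_fn: Callable) -> None:
    BACKEND_REGISTRY[name] = (supported_ops, compile_fn)

def compile_submodule(backend: str, submodule: torch.nn.Module, inputs: Sequence[torch.Tensor]):
    """Dispatch by lookup instead of a long if/elif chain over backend names."""
    if backend not in BACKEND_REGISTRY:
        raise ValueError(f"Unknown backend for submodule: {backend}")
    _, compile_fn = BACKEND_REGISTRY[backend]
    return compile_fn(submodule, inputs)
```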
if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
    getattr(partitioned_module, name).setup_engine()
if use_hierarchical_partitioner:
Similarly, we should standardize the post-processing here as well.
What would be the mergeable version of this PR, since we aren't going to expose the new partitioner yet?
@narendasan As you said, we don't want to expose the APIs of this feature to users for now, so the current code doesn't change any behavior. It can be merged after I resolve the issues above.
Description
Hierarchical Partitioner to support multi-backends.
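A condensed sketch of how the new partitioner is exercised in the included example; the arguments follow examples/hierarchical_partitioner_example.py from this PR, and the lowering imports are assumed to come from torch_tensorrt.dynamo.lowering.

```python
import torch
import torch.nn as nn
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_ATEN_CONVERTERS
from torch_tensorrt.dynamo.lowering import get_decompositions, pre_export_lowering
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import hierarchical_partition

model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU()).cuda().eval()
example_input = torch.randn(1, 3, 224, 224).cuda()

exported_program = torch.export.export(model, (example_input,))
exported_program = pre_export_lowering(exported_program)
exported_program = exported_program.run_decompositions(get_decompositions())
gm = exported_program.module()

# Higher-priority backends claim nodes first; nodes they don't support fall through to the next backend.
partitioned_model, op_support = hierarchical_partition(
    gm,
    verbose=True,
    min_block_size=1,
    backend_priority=["mlir", "tensorrt"],
    backend_support_map={
        "mlir": {torch.ops.aten.convolution.default},
        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
    },
)
```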
Type of change
Checklist: