feat: Hierarchical Partitioner to support multi-backends #3539

Open

zewenli98 wants to merge 2 commits into main from the multi_backend_partitioner branch

Conversation

zewenli98
Collaborator

Description

Adds a Hierarchical Partitioner that splits an FX graph into subgraphs assigned to multiple backends, based on a backend priority list and per-backend operator support maps.
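For context, a condensed usage sketch adapted from examples/hierarchical_partitioner_example.py in this PR. The import path for the lowering helpers (torch_tensorrt.dynamo.lowering) and the trimmed-down SimpleModel are assumptions on my part, and skip_fusion=True reflects the review suggestion below rather than the example as committed:

import torch
import torch.nn as nn
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_ATEN_CONVERTERS
from torch_tensorrt.dynamo.lowering import get_decompositions, pre_export_lowering
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
    hierarchical_partition,
)


class SimpleModel(nn.Module):
    # Tiny conv + batch-norm model, following the example added by this PR.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

    def forward(self, x):
        return torch.relu(self.bn1(self.conv1(x)))


model = SimpleModel().cuda().eval()
example_input = torch.randn(1, 3, 224, 224).cuda()

# Export and lower to an FX GraphModule before partitioning.
exported_program = torch.export.export(model, (example_input,))
exported_program = pre_export_lowering(exported_program)
exported_program = exported_program.run_decompositions(get_decompositions())
gm = exported_program.module()

# Split the graph by backend priority: conv ops go to "mlir" first, anything with a
# TensorRT converter goes to "tensorrt", and torch_executed_ops stay in eager PyTorch.
partitioned_model, op_support = hierarchical_partition(
    gm,
    verbose=True,
    min_block_size=1,
    backend_priority=["mlir", "tensorrt"],
    backend_support_map={
        "mlir": {
            torch.ops.aten.conv2d.default,
            torch.ops.aten.convolution.default,
        },
        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
    },
    torch_executed_ops=[
        torch.ops.aten._native_batch_norm_legit_no_training.default,
    ],
    require_full_compilation=False,
    skip_fusion=True,
)
print(partitioned_model)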

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this May 29, 2025
@zewenli98 zewenli98 marked this pull request as draft May 29, 2025 22:05
@github-actions github-actions bot added documentation Improvements or additions to documentation component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 29, 2025
@github-actions github-actions bot requested a review from peri044 May 29, 2025 22:05
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/hierarchical_partitioner_example.py	2025-05-29 22:05:48.305564+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/hierarchical_partitioner_example.py	2025-05-29 22:06:09.728061+00:00
@@ -6,28 +6,31 @@
    post_lowering,
    pre_export_lowering,
)
import torchvision.models as models
from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import hierarchical_partition
+from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
+    hierarchical_partition,
+)
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition
+
# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition
import operator
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
-    DYNAMO_ATEN_CONVERTERS
+    DYNAMO_ATEN_CONVERTERS,
)


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
-        
+
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
@@ -39,25 +42,22 @@
def main():
    # Create model
    model = SimpleModel().cuda()
    # model = models.efficientnet_b0(pretrained=True).cuda()
    model = model.eval()
-    
+
    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()
-    
+
    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
-    exported_program = exported_program.run_decompositions(
-        get_decompositions()
-    )
+    exported_program = exported_program.run_decompositions(get_decompositions())

    gm = exported_program.module()
-    
+
    print(gm.graph)
-    
-    
+
    # Partition the model using the adjacency partitioner
    # partitioned_model, op_support = partition(
    #     gm,
    #     verbose=True,
    #     min_block_size=1,
@@ -68,16 +68,16 @@

    partitioned_model, op_support = hierarchical_partition(
        gm,
        verbose=True,
        min_block_size=1,
-        backend_priority=["mlir", "tensorrt"], #, "inductor"],
+        backend_priority=["mlir", "tensorrt"],  # , "inductor"],
        backend_support_map={
            "mlir": {
                # operator.getitem,
-                torch.ops.aten.conv2d.default, 
-                torch.ops.aten.convolution.default, 
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.convolution.default,
            },
            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
            # "inductor": {
            #     torch.ops.aten.relu.default,
            # },
@@ -86,18 +86,20 @@
            torch.ops.aten._native_batch_norm_legit_no_training.default
        ],
        require_full_compilation=False,
        skip_fusion=False,
    )
-    
+
    print("\nPartitioned Model Structure:")
    print(partitioned_model)

    with torch.no_grad():
        output = partitioned_model(example_input)
        print("Partitioned output:", output)
-        print("Partitioned output == original output:", torch.allclose(model(example_input), output, 1e-2, 1e-2))
-        
+        print(
+            "Partitioned output == original output:",
+            torch.allclose(model(example_input), output, 1e-2, 1e-2),
+        )


if __name__ == "__main__":
    main()
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py	2025-05-29 22:05:48.319565+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py	2025-05-29 22:06:11.554604+00:00
@@ -3,10 +3,11 @@
from collections import defaultdict
import torch
import torch.fx.passes.operator_support as ops
from torch.fx.node import Target, Node
from torch.fx.graph_module import GraphModule
+
# from torch.fx.passes.splitter_base import (
#     FxNetAccFusionsFinder,
#     FxNetAccNodesFinder,
#     Subgraph,
#     _SplitterBase,
@@ -25,11 +26,11 @@
    MIN_BLOCK_SIZE,
    REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
-    DYNAMO_ATEN_CONVERTERS
+    DYNAMO_ATEN_CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
from torch._ops import OpOverload
from torch.fx.node import _get_qualified_name
import torch_tensorrt.dynamo.partitioning._adjacency_partitioner as _adjacency_partitioner
@@ -38,31 +39,36 @@


class BackendOpSupportTester(ops.OperatorSupportBase):  # type: ignore
    """Class to determine whether operators are supported by specific backends"""

-    def __init__(self, backend_support_map: Dict[str, Set[OpOverload]], backend_priority: List[str], torch_executed_ops: Collection[Target] = set()) -> None:
+    def __init__(
+        self,
+        backend_support_map: Dict[str, Set[OpOverload]],
+        backend_priority: List[str],
+        torch_executed_ops: Collection[Target] = set(),
+    ) -> None:
        super().__init__()

        # Initialize sets of supported/unsupported operators
        self.supported_operators: Dict[str, int] = {}
        self.unsupported_operators: Dict[str, int] = {}
        self.torch_executed_ops = torch_executed_ops
        # Map of backend names to sets of supported operators
        self.backend_support_map = backend_support_map
        # Ordered list of backend names, from highest to lowest priority
        self.backend_priority = backend_priority
-        
+
    def is_node_supported(
        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
    ) -> Tuple[bool, Optional[str]]:
        node_name = ConverterRegistry.qualified_name_or_str(node.target)
-        
+
        for i, backend_name in enumerate(self.backend_priority):
            supported_ops = self.backend_support_map.get(backend_name, set())
            supported_ops = set(_get_qualified_name(op) for op in supported_ops)
-            
+
            if (
                (node_name in supported_ops or node.op == "get_attr")
                and node_name not in self.torch_executed_ops
                and node.target not in self.torch_executed_ops
            ):
@@ -106,11 +112,11 @@
            logger.debug("\nAll Nodes Supported\n")


class HierarchicalTRTPartitioner(_SplitterBase):  # type: ignore
    """Hierarchical partitioner to split an FX graph into subgraphs based on backend priority
-    
+
    This partitioner extends the TRTPartitioner of adjacency_partitioner with backend priority awareness,
    allowing different parts of the model to be executed on different backends based on
    operator support and priority ordering.

    Args:
@@ -155,11 +161,11 @@

        # Get all accelerated nodes based on operator support conditions
        self.acc_nodes = FxNetAccNodesFinder(
            self.module, self.operator_support, self.settings.allow_non_tensor
        )()
-        
+
        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))()

@@ -200,11 +206,13 @@
                    logger.debug(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    # if the last subgraph result[-1] is non-acc or has the same backend, merge the current subgraph into it
-                    if result and (not result[-1].is_acc or result[-1].backend == subgraph.backend):
+                    if result and (
+                        not result[-1].is_acc or result[-1].backend == subgraph.backend
+                    ):
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        # if the last subgraph result[-1] has different backends, then append the current subgraph as non-acc
                        subgraph.is_acc = False
                        subgraph.backend = "None"
@@ -265,19 +273,19 @@
        """Generates starter nodes for partitioning + segmentation"""
        # Starter accelerated nodes are all callable accelerated ops
        starter_acc_nodes = {
            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
        }
-        
+
        # Started non-accelerated nodes are the rest of the callable nodes
        starter_non_acc_nodes = {
            node
            for node in self.module.graph.nodes
            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
        }
        return starter_non_acc_nodes, starter_acc_nodes
-    
+
    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

@@ -301,23 +309,38 @@
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")
                print(222222222222, current_subgraph_nodes, acc_subgraph)
                subgraphs.append(
-                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend if acc_subgraph else "None")
+                    Subgraph(
+                        is_acc=acc_subgraph,
+                        nodes=current_subgraph_nodes,
+                        backend=(
+                            current_subgraph_nodes[-1].backend
+                            if acc_subgraph
+                            else "None"
+                        ),
+                    )
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            # If the backend changed, then it's time to start a new subgraph
-            if current_subgraph_nodes and current_subgraph_nodes[-1].backend != node.backend:
+            if (
+                current_subgraph_nodes
+                and current_subgraph_nodes[-1].backend != node.backend
+            ):
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")
                print(333333333333, current_subgraph_nodes, acc_subgraph)
                subgraphs.append(
-                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend)
+                    Subgraph(
+                        is_acc=acc_subgraph,
+                        nodes=current_subgraph_nodes,
+                        backend=current_subgraph_nodes[-1].backend,
+                    )
                )
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
@@ -343,11 +366,17 @@
                    current_cpu_nodes.add(user)

        # Check if the last subgraph was not created
        if current_subgraph_nodes:
            subgraphs.append(
-                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes, backend=current_subgraph_nodes[-1].backend if acc_subgraph else "None")
+                Subgraph(
+                    is_acc=acc_subgraph,
+                    nodes=current_subgraph_nodes,
+                    backend=(
+                        current_subgraph_nodes[-1].backend if acc_subgraph else "None"
+                    ),
+                )
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")
        print(666666666666, subgraphs)
@@ -384,22 +413,22 @@
    """
    # Ensure graph is clean prior to partitioning
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
-    
+
    # Default backend support map if none provided
    if backend_support_map is None:
        backend_support_map = {
            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
            "inductor": set(),
        }

    # Default backend priority if none provided
    if backend_priority is None:
        backend_priority = ["tensorrt", "inductor"]
-    
+
    # Construct BackendOpSupportTester
    supported_ops = BackendOpSupportTester(
        backend_support_map=backend_support_map,
        backend_priority=backend_priority,
        torch_executed_ops=torch_executed_ops,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/splitter_base.py	2025-05-29 22:05:48.320565+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/splitter_base.py	2025-05-29 22:06:11.987143+00:00
@@ -187,11 +187,13 @@
    def __call__(self) -> NodeSet:
        submodules = dict(self.module.named_modules())
        for n in self.module.graph.nodes:
            n.backend = "None"
            if n.op in CALLABLE_NODE_OPS:
-                is_supported, backend = self.operator_support.is_node_supported(submodules, n)
+                is_supported, backend = self.operator_support.is_node_supported(
+                    submodules, n
+                )
                if is_supported:
                    n.backend = backend
                    self.acc_nodes.add(n)

        if not self.allow_non_tensor:

@zewenli98 zewenli98 requested a review from narendasan May 29, 2025 22:11
@zewenli98 zewenli98 force-pushed the multi_backend_partitioner branch 2 times, most recently from 1872b7c to 79d2616 Compare May 29, 2025 22:25
@zewenli98 zewenli98 force-pushed the multi_backend_partitioner branch from 79d2616 to d97f668 Compare May 29, 2025 23:17
@github-actions github-actions bot removed the documentation Improvements or additions to documentation label May 29, 2025
@zewenli98 zewenli98 marked this pull request as ready for review June 2, 2025 22:01
Collaborator

@peri044 peri044 left a comment

Considering how this feature will be made available, min_block_size and torch_executed_ops need to be rethought or deprecated, as min_block_size can apply to all backends and torch_executed_ops will be replaced by backend_support_map.

torch.ops.aten._native_batch_norm_legit_no_training.default
],
require_full_compilation=False,
skip_fusion=False,
Collaborator

skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Collaborator Author

min_block_size and torch_executed_ops need to be rethought or deprecated, as min_block_size can apply to all backends

My understanding is that if the number of ops in a GM is less than min_block_size, the GM would not be compiled, no matter whether the backend is TRT or another backend. Do you mean it shouldn't apply to other backends?

torch_executed_ops will be replaced by backend_support_map

My understanding is that torch execution is not considered a backend because it doesn't need any compilation; it just runs ops in eager mode. So, if an op is in torch_executed_ops, it ignores backend_support_map and runs in torch eager anyway.
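To make that concrete, here is a hedged, standalone sketch of the per-node decision that BackendOpSupportTester.is_node_supported performs in this PR (the op strings and function name are simplified, not the PR's actual code): torch_executed_ops wins over every backend, and otherwise the first backend in priority order that supports the op is chosen.

from typing import Dict, List, Optional, Set, Tuple

def choose_backend(
    node_name: str,
    backend_priority: List[str],
    backend_support_map: Dict[str, Set[str]],
    torch_executed_ops: Set[str],
) -> Tuple[bool, Optional[str]]:
    """Return (supported, backend) for a single node, mirroring the priority walk."""
    # Ops forced to eager execution bypass every backend, regardless of support.
    if node_name in torch_executed_ops:
        return False, None
    # Walk backends from highest to lowest priority; the first match wins.
    for backend in backend_priority:
        if node_name in backend_support_map.get(backend, set()):
            return True, backend
    return False, None

# Example: batch norm runs in eager even though the TensorRT set could handle it.
support = {"mlir": {"aten.conv2d"}, "tensorrt": {"aten.conv2d", "aten.batch_norm"}}
print(choose_backend("aten.conv2d", ["mlir", "tensorrt"], support, set()))        # (True, 'mlir')
print(choose_backend("aten.batch_norm", ["mlir", "tensorrt"], support,
                     {"aten.batch_norm"}))                                        # (False, None)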

Collaborator Author

skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Since the adjacency partitioner uses this flag, I just kept it here. Yeah, I can definitely switch it to True in the example.

Comment on lines +956 to +966
sub_input = (
    torch.randn(input.shape)
    .to(dtype.to(input.dtype, t=torch.dtype))
    .cuda()
)
sub_inputs.append(sub_input)

compiled_func = torch._inductor.compile(
    submodule,
    sub_inputs,
)
Collaborator

The torch._dynamo.mark_dynamic API is used to set dynamic shapes for the torch.compile workflow. Reference: https://docs.pytorch.org/TensorRT/user_guide/dynamic_shapes.html. So you can use the construct_submodule_inputs() API to give you dynamic inputs (if that's the case) and set them on the Inductor segment accordingly.
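For reference, a minimal sketch of the workflow being suggested, assuming the batch dimension (dim 0) is the dynamic one; the helper name here is invented for illustration and is not part of this PR:

import torch

def compile_inductor_submodule(submodule: torch.nn.Module, sub_inputs):
    # Mark the batch dimension of each example input as dynamic so Inductor
    # does not specialize the compiled code to one concrete batch size.
    for t in sub_inputs:
        torch._dynamo.mark_dynamic(t, 0)
    return torch.compile(submodule, backend="inductor")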

Collaborator Author

submodule_inputs is already the return value of construct_submodule_inputs(). What I did here is convert the torch-trt Input objects to torch Tensors. I changed it to:

sub_input = (
    input.torch_tensor
    .to(dtype.to(input.dtype, t=torch.dtype))
    .cuda()
)

Comment on lines +978 to +979
else:
    raise ValueError(f"Unknown backend for submodule: {name}")
Collaborator

Should there be a _run_on_gpu segment here? How would torch.ops.aten._native_batch_norm_legit_no_training.default fall back to native PyTorch in your example from Slack?

Collaborator Author

Nope, as I mentioned above, _run_on_gpu is not considered a backend. We just keep the module as is. Only _run_on_acc_backend modules need to be replaced with the compiled module. This aligns with our existing implementation.
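A hedged sketch of the dispatch being discussed, using the submodule name prefixes quoted in this review; trt_convert and wrapper_cls stand in for the PR's convert_module and FunctionWrapper/InductorModule helpers and are passed in rather than imported:

import torch

def replace_compiled_submodules(partitioned_module, example_inputs, trt_convert, wrapper_cls):
    """Swap each accelerated submodule for its compiled form; leave eager ones alone."""
    for name, submodule in list(partitioned_module.named_children()):
        if "_run_on_acc_inductor" in name:
            # Compile with Inductor and wrap the returned callable as an nn.Module.
            compiled_func = torch._inductor.compile(submodule, example_inputs)
            setattr(partitioned_module, name, wrapper_cls(compiled_func))
        elif "_run_on_acc_tensorrt" in name:
            # Convert the FX submodule to a TensorRT engine module.
            setattr(partitioned_module, name, trt_convert(submodule, example_inputs))
        elif "_run_on_gpu" in name:
            # Not a backend: keep the submodule as-is and run it in eager mode.
            continue
        else:
            raise ValueError(f"Unknown backend for submodule: {name}")
    return partitioned_module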

Comment on lines +1348 to +1354
class FunctionWrapper(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)
Collaborator

Consider renaming this to InductorModule and moving it to utils.

require_full_compilation=require_full_compilation,
skip_fusion=skip_fusion,
)

Collaborator

Do you think copying all this code is necessary, or would inheriting only some components be sufficient?

# Wrap the compiled function to be a torch.nn.Module
compiled_submodule = FunctionWrapper(compiled_func)

elif "_run_on_acc_tensorrt" in name:
Collaborator

Is there some sort of design where the capability and conversion parts can be grouped and registered? We can add this concept to the RFC for later

Collaborator

That way we don't have a long conditional case set; we just look up the appropriate conversion function through a standardized API.

Collaborator Author

Yeah, I agree that the conversion parts for different backends should be grouped, but for now I don't have much info about other backends (like how to convert an op to that backend). We can definitely do this when we're ready to support other backends.
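A hedged sketch of what such a registration scheme might look like later; nothing here exists in the codebase, and all names are invented for illustration:

from typing import Callable, Dict

import torch

# Hypothetical registry: backend name -> function that compiles an FX submodule.
BACKEND_COMPILERS: Dict[str, Callable[..., torch.nn.Module]] = {}

def register_backend_compiler(name: str):
    def decorator(fn: Callable[..., torch.nn.Module]):
        BACKEND_COMPILERS[name] = fn
        return fn
    return decorator

@register_backend_compiler("inductor")
def compile_with_inductor(submodule: torch.nn.Module, example_inputs) -> torch.nn.Module:
    return torch.compile(submodule, backend="inductor")

# Dispatch then becomes a lookup instead of a chain of elif branches:
# compiled = BACKEND_COMPILERS[backend_name](submodule, example_inputs)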

if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
    getattr(partitioned_module, name).setup_engine()
if use_hierarchical_partitioner:
Collaborator

Similarly here, we should standardize post-processing as well.

Collaborator

@narendasan narendasan left a comment

What would be the merge-able version of this PR, since we aren't going to expose the new partitioner yet?

@zewenli98
Collaborator Author

What would be the merge-able version of this PR, since we aren't going to expose the new partitioner yet?

@narendasan As you said, we don't want to expose the APIs of this feature to users for now, so the current code doesn't change any behavior. It can be merged after I resolve the issues above.

Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
4 participants