# Hierarchical Partitioner for TensorRT

The Hierarchical Partitioner is an extension to the standard TensorRT partitioner that allows for more sophisticated partitioning strategies by considering backend priority and operator support. This is particularly useful when you want to distribute different parts of your model across multiple backends based on their capabilities and priorities.

## Overview

The Hierarchical Partitioner extends the standard TensorRT partitioner with the following capabilities:

1. **Backend priority ordering**: Assign operators to backends based on a priority order, ensuring that each operator is assigned to the highest-priority backend that supports it.
2. **Multi-backend support**: Distribute model execution across multiple backends based on operator support.
3. **Flexible block size requirements**: Specify minimum block sizes for TensorRT acceleration.
## Usage

### Basic Usage

```python
import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import hierarchical_partition

# Create and trace your model (torch.export.export expects a tuple of example inputs)
model = YourModel()
graph_module = torch.export.export(model, (example_input,)).module()

# Define backend support map
backend_support_map = {
    "backend_1": {"op_x", "op_y", "op_z"},
    "backend_2": {"op_a", "op_b", "op_c"},
    "backend_3": {"op_c", "op_x"},
}

# Define backend priority (from highest to lowest)
backend_priority = ["backend_3", "backend_1", "backend_2"]

# Ops that should always run in PyTorch, regardless of backend support
torch_executed_ops = [
    "op_1", "op_2",
]

# Partition the model using the hierarchical partitioner
partitioned_model, op_support = hierarchical_partition(
    graph_module,
    verbose=True,
    min_block_size=1,
    backend_support_map=backend_support_map,
    backend_priority=backend_priority,
    torch_executed_ops=torch_executed_ops,
)
```

### Backend Support Map

The `backend_support_map` parameter is a dictionary that maps backend names to sets of supported operators. Each backend is defined by the set of operators it can execute.

Example:

```python
backend_support_map = {
    "mlir": {
        torch.ops.aten.conv2d.default,
        torch.ops.aten.convolution.default,
    },
    "inductor": {
        torch.ops.aten.relu.default,
    },
    "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
}
```

### Backend Priority

The `backend_priority` parameter is an ordered list of backend names, from highest to lowest priority. When an operator is supported by multiple backends, it is assigned to the highest-priority backend that supports it.

Example:

```python
backend_priority = ["mlir", "tensorrt", "inductor"]
```
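
To make the selection rule concrete, here is a minimal, self-contained sketch of how priority resolution behaves conceptually. `resolve_backend` is a hypothetical helper written for illustration, not part of the library's API:

```python
def resolve_backend(op, backend_support_map, backend_priority):
    """Return the highest-priority backend that supports `op`,
    or None if no backend supports it (fall back to PyTorch)."""
    for backend in backend_priority:
        if op in backend_support_map.get(backend, set()):
            return backend
    return None

# With the maps above, conv2d is claimed by "mlir" even if the TensorRT
# converter registry also supports it, because "mlir" appears first in
# backend_priority.
```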

### Complete Example

Here's a complete example of how to use the hierarchical partitioner with a simple model:

```python
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import hierarchical_partition
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_ATEN_CONVERTERS
from torch_tensorrt.dynamo.lowering import get_decompositions, pre_export_lowering


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Two 2x poolings downsample 224x224 -> 56x56 so the flattened
        # size matches self.fc's input features
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(128 * 56 * 56, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def main():
    # Create model and move it to the GPU to match the example input
    model = SimpleModel().cuda()
    model.eval()

    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()

    # Get GraphModule: export, lower, and decompose the model
    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
    exported_program = exported_program.run_decompositions(
        get_decompositions()
    )
    gm = exported_program.module()

    # Define backend support map
    backend_support_map = {
        "mlir": {
            torch.ops.aten.conv2d.default,
            torch.ops.aten.convolution.default,
        },
        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
    }

    # Define backend priority
    backend_priority = ["mlir", "tensorrt"]

    # Define operations executed by torch (shown for illustration;
    # this particular model contains no batch norm)
    torch_executed_ops = [
        torch.ops.aten._native_batch_norm_legit_no_training.default,
    ]

    # Partition the model using the hierarchical partitioner
    partitioned_model, op_support = hierarchical_partition(
        gm,
        verbose=True,
        min_block_size=1,
        backend_support_map=backend_support_map,
        backend_priority=backend_priority,
        torch_executed_ops=torch_executed_ops,
    )

    # Print the partitioned model
    print("\nPartitioned Model Structure:")
    print(partitioned_model)

    # Run inference with the partitioned model and compare against the
    # original model's output
    with torch.no_grad():
        output = partitioned_model(example_input)
        print("Partitioned output:", output)
        print(
            "Partitioned output == original output:",
            torch.allclose(model(example_input), output, rtol=1e-2, atol=1e-2),
        )


if __name__ == "__main__":
    main()
```

## Implementation Details

The Hierarchical Partitioner consists of two main components:

1. **BackendOpSupportTester**: Extends the standard `OpSupportTester` to be aware of backend support and priority, assigning each node to the highest-priority backend that supports it.

2. **HierarchicalTRTPartitioner**: Extends the standard `TRTPartitioner` to partition the graph based on backend assignments.

The partitioner works by:

1. Identifying which backend each node should be assigned to, based on operator support and priority
2. Partitioning the graph based on these backend assignments
3. Applying minimum block size requirements when removing small subgraphs (a simplified sketch of steps 2 and 3 follows this list)

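The snippet below is a minimal, self-contained sketch of steps 2 and 3, complementing the `resolve_backend` sketch above (step 1). It models the graph as a linear sequence of already-assigned nodes and demotes undersized blocks to PyTorch; the actual classes operate on FX graphs and are considerably more involved:

```python
def group_into_blocks(assignments, min_block_size):
    """Group a linear sequence of (node, backend) pairs into contiguous
    same-backend blocks, then demote blocks smaller than min_block_size
    to PyTorch execution."""
    blocks = []
    for node, backend in assignments:
        if blocks and blocks[-1][0] == backend:
            blocks[-1][1].append(node)  # extend the current block
        else:
            blocks.append((backend, [node]))  # start a new block
    return [
        ("pytorch" if backend != "pytorch" and len(nodes) < min_block_size
         else backend, nodes)
        for backend, nodes in blocks
    ]

# Example: two convs claimed by "mlir", a lone relu claimed by "tensorrt";
# with min_block_size=2 the relu block is too small and falls back to PyTorch.
assignments = [("conv1", "mlir"), ("conv2", "mlir"), ("relu", "tensorrt")]
print(group_into_blocks(assignments, min_block_size=2))
# [('mlir', ['conv1', 'conv2']), ('pytorch', ['relu'])]
```
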
## Advanced Usage

### Custom Backend Support

You can define custom backends with specific operator support by modifying the `backend_support_map` parameter. This allows you to create specialized backends for specific types of operations.
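
For instance, you might register an in-house accelerator that only handles matrix multiplications and ReLU, and give it priority over TensorRT. The `my_accelerator` name and its operator set are illustrative assumptions, not a built-in backend:

```python
custom_backend_support_map = {
    # Hypothetical custom backend supporting only two operators
    "my_accelerator": {
        torch.ops.aten.mm.default,
        torch.ops.aten.relu.default,
    },
    # TensorRT handles everything in its converter registry
    "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
}
custom_backend_priority = ["my_accelerator", "tensorrt"]
```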

### Default Backend Support

If no `backend_support_map` is provided, the partitioner uses a default map with two backends:
- `tensorrt`: supports all operators in the TensorRT converter registry
- `pytorch`: supports all remaining operators

Similarly, if no `backend_priority` is provided, the default priority order is `["tensorrt", "pytorch"]`.
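
Under these defaults, operators with TensorRT converters go to TensorRT and everything else falls back to PyTorch, which mirrors the standard partitioning flow. A hedged sketch, assuming the same call signature as the examples above:

```python
# Both backend_support_map and backend_priority omitted, so the
# defaults described above apply.
partitioned_model, op_support = hierarchical_partition(
    gm,
    min_block_size=1,
)
```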

## Limitations

- The current implementation does not support dynamic shapes or control flow.
- The partitioner does not currently support custom operators or fusion patterns.
- The partitioner assumes that operators are atomic and cannot be split across backends.

## Future Work

- Support for dynamic shapes and control flow
- Support for custom operators and fusion patterns
- Integration with the TensorRT profiler to optimize partitioning based on performance metrics