
Commit a2564df

feat: cherry-pick of torch.compile dynamic shapes (#2750)

1 parent 67675d7

File tree

8 files changed: +193, -118 lines
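
For context, the feature being cherry-picked is dynamic-shape support in the torch.compile path. A minimal sketch of the user-facing flow, assuming a CUDA machine with TensorRT and this build of Torch-TensorRT installed (the toy model and shapes below are illustrative, not from this commit):

import torch
import torch_tensorrt  # importing registers the "torch_tensorrt" dynamo backend

model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()).eval().cuda()

# dynamic=True makes Dynamo trace with symbolic (SymInt) shapes; those SymInt
# placeholders are exactly what this commit teaches the backend to handle.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=True)

# Differing batch sizes now reuse one compiled artifact instead of recompiling.
print(compiled(torch.randn(8, 64, device="cuda")).shape)
print(compiled(torch.randn(16, 64, device="cuda")).shape)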

.github/workflows/build-test.yml

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ jobs:
       os: linux
       test-infra-repository: pytorch/test-infra
       test-infra-ref: main
+      channel: test
       with-rocm: false
       with-cpu: false

@@ -208,6 +209,7 @@ jobs:
           ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
           ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
           popd

   tests-py-dynamo-core:
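
The added CI step runs the new dynamic-shape test suite. The contents of models/test_dyn_models.py are not shown in this commit view; a hypothetical sketch of the kind of test it would contain (names and shapes are illustrative, and a CUDA machine with TensorRT is assumed):

import torch
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)


def test_dyn_relu():
    model = torch.nn.ReLU().eval().cuda()
    x = torch.randn(4, 3, 224, 224, device="cuda")

    # Mark the batch dimension as dynamic so Dynamo traces it as a SymInt.
    torch._dynamo.mark_dynamic(x, 0)

    compiled = torch.compile(model, backend="torch_tensorrt")
    torch.testing.assert_close(compiled(x), model(x), rtol=1e-3, atol=1e-3)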

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 7 deletions

@@ -273,14 +273,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 return False
         return True

-    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    # Check if the module has metadata (shape, dtype).
     if not contains_metadata(gm):
-        from torch._inductor.compile_fx import fake_tensor_prop
-
-        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
-        with torch.no_grad():
-            # This fails if the module has data-dependent shape operators.
-            fake_tensor_prop(gm, torch_inputs)
+        # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
+        logger.warning(
+            "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
+        )

     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
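
For readers unfamiliar with the metadata this check refers to: FX nodes store shape and dtype information in node.meta, typically populated by shape propagation. A minimal sketch of inspecting it with torch.fx.passes.shape_prop.ShapeProp (the toy module is illustrative, not part of this commit):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

gm = symbolic_trace(torch.nn.Linear(4, 2))

# Populate node.meta["tensor_meta"] (shape, dtype, ...) on every node.
ShapeProp(gm).propagate(torch.randn(3, 4))

for node in gm.graph.nodes:
    # Mirrors the spirit of contains_metadata(): call nodes with an empty
    # node.meta carry no shape/dtype information for the partitioner to use.
    print(node.op, node.name, node.meta.get("tensor_meta"))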

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 14 additions & 6 deletions

@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
+    remove_sym_nodes,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (

@@ -27,7 +28,7 @@
 @td.register_backend(name="tensorrt")  # type: ignore[misc]
 @td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
     if (

@@ -44,15 +45,15 @@ def torch_tensorrt_backend(

 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
     return _pretraced_backend(gm, sample_inputs, settings)


 def _pretraced_backend(
     gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines

@@ -74,10 +75,17 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
+
+            # Remove sym_int placeholders and inputs
+            remove_sym_nodes(gm)
+            torch_inputs = [
+                input for input in sample_inputs if isinstance(input, torch.Tensor)
+            ]
+
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                sample_inputs,
+                torch_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions

@@ -86,10 +94,10 @@ def _pretraced_backend(

             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            gm = apply_lowering_passes(gm, sample_inputs)
+            gm = apply_lowering_passes(gm, torch_inputs)

             torchtrt_inputs = prepare_inputs(
-                sample_inputs, disable_memory_format_check=True
+                torch_inputs, disable_memory_format_check=True
             )
             trt_compiled = compile_module(
                 gm,
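
The new filtering exists because, with dynamic shapes enabled, Dynamo hands the backend torch.SymInt entries alongside tensors in sample_inputs. A quick way to observe this, using a hypothetical debugging backend (inspect_backend is not part of this commit):

import torch


def inspect_backend(gm: torch.fx.GraphModule, sample_inputs):
    # With dynamic shapes, sample_inputs may mix torch.SymInt entries
    # (symbolic dimension sizes) with (fake) tensors.
    for i, inp in enumerate(sample_inputs):
        print(i, type(inp).__name__)
    return gm.forward  # fall back to eager execution


model = torch.nn.Linear(8, 8).eval()
compiled = torch.compile(model, backend=inspect_backend, dynamic=True)
compiled(torch.randn(4, 8))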

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 1 addition & 0 deletions

@@ -104,6 +104,7 @@ def get_shape_with_dynamic_shape(
     scale_res = scale_layer.get_output(0)

     length = input_shape.shape[0]
+
     zero_layer = ctx.net.add_constant(
         input_shape.shape, np.zeros((length), dtype=np.int32)
     )

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -3,6 +3,6 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F401
+from ._remove_sym_nodes import remove_sym_nodes
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes

py/torch_tensorrt/dynamo/lowering/_fusers.py

Lines changed: 0 additions & 82 deletions

This file was deleted.
py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py (new file; path inferred from the __init__.py import above)

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove sym_int placeholders which get inserted due to torch.compile's
+    dynamic=True behavior
+    """
+    # Extract SymInt placeholder Tensors
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.SymInt)
+        )
+    ]
+
+    for node in placeholders:
+        gm.graph.erase_node(node)
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")
+
+    return gm
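
A minimal sketch of this pass in action on a hand-built FX graph (the toy graph mimics the SymInt placeholders torch.compile inserts with dynamic=True; it is illustrative, not from this commit):

import torch
from torch.fx import Graph, GraphModule
from torch_tensorrt.dynamo.lowering import remove_sym_nodes

graph = Graph()
x = graph.placeholder("x", type_expr=torch.Tensor)
s = graph.placeholder("s", type_expr=torch.SymInt)  # unused symbolic size input
out = graph.call_function(torch.relu, (x,))
graph.output(out)

gm = GraphModule(torch.nn.Module(), graph)
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x', 's']

remove_sym_nodes(gm)
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x']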
