
Commit bb990fd

Added CPU offloading (#3452)
1 parent 09e865f commit bb990fd

13 files changed (+640 −18 lines)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 44 additions & 11 deletions
@@ -37,6 +37,7 @@
     pre_export_lowering,
 )
 from torch_tensorrt.dynamo.utils import (
+    deallocate_module,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -98,6 +99,7 @@ def cross_compile_for_windows(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -362,7 +364,18 @@ def cross_compile_for_windows(
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
-
+    # Move the weights in the state_dict to CPU
+    if offload_module_to_cpu:
+        deallocate_module(exported_program.module(), delete_module=False)
+        logger.info(
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
+        )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
+            )
     trt_gm = compile_module(
         gm,
         trt_arg_inputs,
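For reference, the else branch above is a plain headroom check built on torch.cuda.mem_get_info(); below is a standalone sketch of the same heuristic, where the one-half threshold mirrors the hard-coded value in the hunk and is not a configurable API.

    import torch

    # mem_get_info() returns (free_bytes, total_bytes) for the current CUDA device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    if free_bytes < total_bytes // 2:
        # Less than half of the GPU is free, so the engine build may hit an OOM;
        # this is the situation offload_module_to_cpu=True is meant to relieve.
        print("Consider compiling with offload_module_to_cpu=True")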
@@ -421,6 +434,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -498,6 +512,7 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
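As context for the new docstring entry, here is a minimal usage sketch of the flag; the model, input shape, and trace/compile flow mirror the tests added later in this commit rather than anything mandated by the docstring itself.

    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models

    model = models.resnet18().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    exp_program = torchtrt.dynamo.trace(model, inputs=inputs)
    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        min_block_size=1,
        offload_module_to_cpu=True,  # weights move to the CPU while TensorRT builds the engine
    )
    model.cuda()  # move the original module back if you still need it on the GPU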
@@ -550,15 +565,6 @@ def compile(
             "`immutable_weights` must be False when `refit_identical_engine_weights` is True."
         )

-    if (
-        not immutable_weights
-        and not refit_identical_engine_weights
-        and enable_weight_streaming
-    ):
-        raise ValueError(
-            "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
-        )
-
     if (
         "enable_cross_compile_for_windows" in kwargs.keys()
         and kwargs["enable_cross_compile_for_windows"]
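Because the guard above is deleted, passing refittable weights together with weight streaming no longer fails settings validation. The following is only an illustration of what the removal permits, reusing exp_program and inputs from the sketch earlier in this commit; whether every such combination is fully supported end to end is not established by this diff.

    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        immutable_weights=False,       # refittable engine weights
        enable_weight_streaming=True,  # previously rejected together with refit
        use_explicit_typing=True,      # weight streaming generally expects a strongly typed network
    )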
@@ -674,6 +680,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
@@ -690,6 +697,18 @@ def compile(
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))

+    # Move the weights in the state_dict to CPU
+    if offload_module_to_cpu:
+        deallocate_module(exported_program.module(), delete_module=False)
+        logger.info(
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
+        )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
@@ -820,6 +839,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -833,6 +853,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
+            submodule.to(to_torch_device(settings.device))
             continue

         if name not in submodule_node_dict:
@@ -964,6 +985,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1147,6 +1169,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
@@ -1166,7 +1189,17 @@ def convert_exported_program_to_serialized_trt_engine(

     # Configure user compilation settings to converters.
     CONVERTERS.set_compilation_settings(settings)
-
+    if offload_module_to_cpu:
+        deallocate_module(exported_program.module(), delete_module=False)
+        logger.info(
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
+        )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
+            )
     try:
         interpreter_result = interpret_module_to_result(
             gm,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
+OFFLOAD_MODULE_TO_CPU = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
+    OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
+    offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
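Since the flag is now a first-class field on the dataclass, it can also be set when constructing settings directly. A small sketch follows; the second field shown is an arbitrary example, not a requirement of this change.

    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings(offload_module_to_cpu=True, min_block_size=1)
    print(settings.offload_module_to_cpu)  # True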

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 2 deletions
@@ -45,7 +45,7 @@
     get_trt_tensor,
     to_torch,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER

@@ -731,7 +731,8 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-
+        if self.compilation_settings.offload_module_to_cpu:
+            deallocate_module(self.module)
         serialized_engine = self.builder.build_serialized_network(
             self.ctx.net, builder_config
         )
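When offload_module_to_cpu is set, the interpreter drops its own copy of the lowered FX module (the default delete_module=True path) right before build_serialized_network, which is typically the peak-memory point of a build. A rough, assumption-laden way to observe the effect from user code is to compare free GPU memory around compilation:

    import torch

    free_before, _ = torch.cuda.mem_get_info()
    # ... call torchtrt.dynamo.compile(..., offload_module_to_cpu=True) here ...
    free_after, _ = torch.cuda.mem_get_info()
    print(f"Change in free GPU memory: {(free_after - free_before) / 2**20:.1f} MiB")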

py/torch_tensorrt/dynamo/utils.py

Lines changed: 3 additions & 2 deletions
@@ -84,13 +84,14 @@ class Frameworks(Enum):
 }


-def delete_module(module: torch.fx.GraphModule) -> None:
+def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
     delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
     """
     module.to(CPU_DEVICE)
-    del module
+    if delete_module:
+        del module
     torch.cuda.empty_cache()
     gc.collect()
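A standalone sketch of the two modes of the renamed helper; the traced Linear module is only a stand-in, and the function plus its new delete_module flag are the only parts taken from this diff.

    import torch
    from torch_tensorrt.dynamo.utils import deallocate_module

    gm = torch.fx.symbolic_trace(torch.nn.Linear(8, 8).cuda())

    # Offload only: parameters move to the CPU, the GraphModule object stays usable.
    deallocate_module(gm, delete_module=False)
    print(next(gm.parameters()).device)  # cpu

    # Default behaviour: offload, drop the callee's reference, empty the CUDA cache, and gc.collect().
    deallocate_module(gm)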

tests/py/dynamo/models/test_export_serde.py

Lines changed: 113 additions & 1 deletion
@@ -6,7 +6,11 @@
 import torch
 import torch_tensorrt as torchtrt
 import torchvision.models as models
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+from torch_tensorrt.dynamo.utils import (
+    COSINE_THRESHOLD,
+    cosine_similarity,
+    get_model_device,
+)

 assertions = unittest.TestCase()

@@ -283,6 +287,53 @@ def test_resnet18(ir):
     )


+@pytest.mark.unit
+def test_resnet18_cpu_offload(ir):
+    """
+    This tests export save and load functionality on Resnet18 model
+    """
+    model = models.resnet18().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    assertions.assertTrue(
+        get_model_device(model).type == "cpu",
+        msg="Model should be offloaded to CPU",
+    )
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    outputs_trt_deser = deser_trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+
 @pytest.mark.unit
 def test_resnet18_dynamic(ir):
     """
@@ -381,6 +432,67 @@ def forward(self, x):
     )


+@pytest.mark.unit
+def test_hybrid_conv_fallback_cpu_offload(ir):
+    """
+    This tests export save and load functionality on a hybrid
+    model where a conv (a weighted layer) has been forced to fallback to Pytorch.
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            conv = self.conv(x)
+            relu = self.relu(conv)
+            mul = relu * 0.5
+            return mul
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "torch_executed_ops": {"torch.ops.aten.convolution.default"},
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    outputs_trt_deser = deser_trt_module(input)
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
 @pytest.mark.unit
 def test_arange_export(ir):
     """