66
77import numpy as np
88import torch
9+ import torch_tensorrt
910from torch_tensorrt ._Device import Device
1011from torch_tensorrt .dynamo import _defaults
1112from torch_tensorrt .dynamo ._compiler import compile as dynamo_compile
@@ -61,6 +62,7 @@ def __init__(
6162 * ,
6263 device : Optional [Union [Device , torch .device , str ]] = _defaults .DEVICE ,
6364 use_python_runtime : bool = _defaults .USE_PYTHON_RUNTIME ,
65+ enable_cuda_graph : bool = True ,
6466 immutable_weights : bool = False ,
6567 strict : bool = True ,
6668 allow_complex_guards_as_runtime_asserts : bool = False ,
@@ -127,6 +129,7 @@ def __init__(
127129 self .arg_inputs : tuple [Any , ...] = tuple ()
128130 self .kwarg_inputs : dict [str , Any ] = {}
129131 self .additional_settings = kwargs
132+ self .enable_cuda_graph = enable_cuda_graph
130133 self .strict = strict
131134 self .allow_complex_guards_as_runtime_asserts = (
132135 allow_complex_guards_as_runtime_asserts
@@ -142,7 +145,11 @@ def __init__(
142145 self .run_info : Optional [tuple [Any , ...]] = None
143146 self .state_dict_metadata : dict [str , torch .Size ] = {}
144147 self ._store_state_dict_metadata ()
145-
148+ self .enable_weight_streaming = (
149+ kwargs ["enable_weight_streaming" ]
150+ if "enable_weight_streaming" in kwargs
151+ else False
152+ )
146153 cls = self .__class__
147154 self .__class__ = type (
148155 self .original_model .__class__ .__name__ ,
@@ -193,7 +200,7 @@ def forward(a, b, c=0, d=0):
193200
194201 self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
195202
196- def _get_total_dynamic_shapes (self ) -> Union [ dict [str , Any ], None ] :
203+ def _get_total_dynamic_shapes (self ) -> dict [str , Any ] | None :
197204 if not self .arg_dynamic_shapes and not self .kwarg_dynamic_shapes :
198205 return None
199206 total_dynamic_shape = {}
@@ -266,15 +273,17 @@ def refit_gm(self) -> None:
266273 MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module.
267274 If it fails to catch the changes, please call this function manually to update the TRT graph module.
268275 """
269- self . original_model . to ( to_torch_device ( self . trt_device ))
276+
270277 if self .exp_program is None :
278+ self .original_model .to (to_torch_device (self .trt_device ))
271279 self .exp_program = self .get_exported_program ()
272280 else :
273281 self .exp_program ._state_dict = (
274282 MutableTorchTensorRTModule ._transform_state_dict (
275283 self .original_model .state_dict ()
276284 )
277285 )
286+ self .exp_program .module ().to (to_torch_device (self .trt_device ))
278287 self .gm = refit_module_weights (
279288 self .gm ,
280289 self .exp_program ,
@@ -284,7 +293,7 @@ def refit_gm(self) -> None:
284293 in_place = True ,
285294 )
286295
287- self .original_model .cpu ( )
296+ self .original_model .to ( "cpu" )
288297 torch .cuda .empty_cache ()
289298
290299 def get_exported_program (self ) -> torch .export .ExportedProgram :
@@ -324,8 +333,15 @@ def compile(self) -> None:
324333 use_python_runtime = self .use_python_runtime ,
325334 ** self .additional_settings ,
326335 )
327- self .original_model .cpu ( )
336+ self .original_model .to ( "cpu" )
328337 torch .cuda .empty_cache ()
338+ # torch_tensorrt.runtime.set_cudagraphs_mode(self.enable_cuda_graph)
339+ # if self.enable_cuda_graph:
340+ # self.gm = torch_tensorrt.runtime.enable_cudagraphs(self.gm)
341+ if self .enable_weight_streaming :
342+ self .weight_streaming_ctx = torch_tensorrt .runtime .weight_streaming (self .gm )
343+ requested_budget = int (16 * 2 << 20 )
344+ self .weight_streaming_ctx .device_budget = requested_budget
329345
330346 def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
331347
@@ -446,14 +462,21 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
446462 self ._store_state_dict_metadata ()
447463 self .refit_state .set_state (RefitFlag .LIVE )
448464
465+ # weight_streaming_ctx = self.weight_streaming_ctx if self.enable_weight_streaming else None
449466 result = self .gm (* args , ** kwargs )
450467 # Storing inputs and outputs for verification when the state is unknown
451468 self .run_info = (args , kwargs , result )
452469 return result
453470
454- def to (self , device : str ) -> None :
455- logger .warning ("Original PyTorch model is moved. CPU offload may failed." )
456- self .original_model .to (device )
def to(self, *args: Any, **kwargs: Any) -> None:
    """Intercept ``nn.Module.to`` and refuse to move the wrapped model.

    Moving the original PyTorch model would break the CPU-offload strategy
    used by this module and increase GPU memory usage, so this override
    performs no device transfer and only logs a warning. If the move is
    truly required, call ``module.pytorch_model.to(...)`` directly.

    Args:
        *args: Ignored; accepted so any ``nn.Module.to(...)`` call
            signature does not crash.
        **kwargs: Ignored, for the same reason.
    """
    # Fix: the original message concatenated two sentences with no
    # separating space ("...usage.If this is absolute necessary...") and
    # contained grammatical errors.
    logger.warning(
        "Trying to move the original PyTorch model. This will cause CPU "
        "offloading to fail and increase GPU memory usage. If this is "
        "absolutely necessary, please call module.pytorch_model.to(...)"
    )
476+
@property
def device(self) -> torch.device:
    """Resolve and return the torch device targeted by the TRT module."""
    resolved: torch.device = to_torch_device(self.trt_device)
    return resolved
457480
458481 def __deepcopy__ (self , memo : Any ) -> Any :
459482 cls = self .__class__
0 commit comments