@@ -521,6 +521,10 @@ def test_refit_one_engine_bert_with_weightmap():
521521 torch ._dynamo .reset ()
522522
523523
524+ @unittest .skipIf (
525+ not importlib .util .find_spec ("torchvision" ),
526+ "torchvision is not installed" ,
527+ )
524528@unittest .skipIf (
525529 not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
526530 "TorchScript Frontend is not available" ,
@@ -577,6 +581,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap(tmpdir):
577581 torch ._dynamo .reset ()
578582
579583
584+ @unittest .skipIf (
585+ not importlib .util .find_spec ("torchvision" ),
586+ "torchvision is not installed" ,
587+ )
580588@unittest .skipIf (
581589 not torch_trt .ENABLED_FEATURES .refit ,
582590 "Refit feature is not supported in Python 3.13 or higher" ,
@@ -764,6 +772,10 @@ def forward(self, x):
764772 torch ._dynamo .reset ()
765773
766774
775+ @unittest .skipIf (
776+ not importlib .util .find_spec ("torchvision" ),
777+ "torchvision is not installed" ,
778+ )
767779@unittest .skipIf (
768780 not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
769781 "TorchScript Frontend is not available" ,
@@ -879,6 +891,10 @@ def test_refit_one_engine_bert_without_weightmap():
879891 torch ._dynamo .reset ()
880892
881893
894+ @unittest .skipIf (
895+ not importlib .util .find_spec ("torchvision" ),
896+ "torchvision is not installed" ,
897+ )
882898@unittest .skipIf (
883899 not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
884900 "TorchScript Frontend is not available" ,
@@ -932,6 +948,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(tmpdir):
932948 torch ._dynamo .reset ()
933949
934950
951+ @unittest .skipIf (
952+ not importlib .util .find_spec ("torchvision" ),
953+ "torchvision is not installed" ,
954+ )
935955@unittest .skipIf (
936956 not torch_trt .ENABLED_FEATURES .refit ,
937957 "Refit feature is not supported in Python 3.13 or higher" ,
@@ -1107,3 +1127,220 @@ def forward(self, x):
11071127 # Clean up model env
11081128
11091129 torch ._dynamo .reset ()
1130+
1131+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_complex_buffer_refit():
    """Refit a model whose weights include a complex-valued buffer (e.g. RoPE freqs).

    Two models that differ only in their complex ``freqs`` buffer are exported.
    The first is compiled with ``immutable_weights=False``; the compiled module
    is then refitted with the second exported program, and the refitted output
    is compared against the second program's eager output.

    Intended coverage (combined complex_graph_detection + refit_module_weights):
      - the complex get_attr buffer is unpacked to real by the lowering pass
      - the complex placeholder input goes through view_as_real at the TRT boundary
      - after refitting with new frequencies the output matches the new model
    """

    class ComplexFreqModel(nn.Module):
        def __init__(self, freqs: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs", freqs.cuda())

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            # complex mul then expose as real so TRT can produce a real output
            return torch.view_as_real(z * self.freqs)

    SEQ, DIM = 8, 32

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex values with random phase, shape (SEQ, DIM // 2).
        angles = torch.rand(SEQ, DIM // 2)
        return torch.polar(torch.ones_like(angles), angles).cuda()

    # try/finally guarantees the dynamo state is reset even when export,
    # compilation, refit or the final assertion raises, so a failure here
    # cannot leak cached state into subsequent tests.
    try:
        freqs1 = make_freqs()
        freqs2 = make_freqs()

        model1 = ComplexFreqModel(freqs1).eval()
        model2 = ComplexFreqModel(freqs2).eval()

        z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda()
        inputs = [z]

        exp_program1 = torch.export.export(model1, tuple(inputs))
        exp_program2 = torch.export.export(model2, tuple(inputs))

        # immutable_weights=False is what makes the engine refittable.
        trt_gm = torchtrt.dynamo.compile(
            exp_program1,
            tuple(inputs),
            use_python_runtime=True,
            enabled_precisions={torch.float},
            min_block_size=1,
            immutable_weights=False,
        )

        new_trt_gm = refit_module_weights(
            compiled_module=trt_gm,
            new_weight_module=exp_program2,
            arg_inputs=inputs,
            use_weight_map_cache=True,
            verify_output=True,
        )

        expected_output = exp_program2.module()(*inputs)
        refitted_output = new_trt_gm(*inputs)

        assertions.assertTrue(
            torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2),
            "Refit with complex buffer failed: output mismatch after refit",
        )
    finally:
        # Clean up model env
        torch._dynamo.reset()
1203+
1204+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_complex_buffer_with_real_param_refit():
    """Refit a model that mixes a complex buffer with a real nn.Linear weight.

    Intended to verify that Stage 3 slice-matching for complex buffer constants
    coexists correctly with ordinary weight-name-map entries for real
    parameters. After refitting both the frequencies and the projection matrix,
    the refitted output must match the new model's eager output within the
    comparison tolerance used below (atol=1e-2, rtol=1e-2).
    """

    SEQ, DIM = 8, 32

    class ComplexRotateAndProject(nn.Module):
        def __init__(self, freqs: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs", freqs.cuda())
            self.proj = nn.Linear(DIM, DIM, bias=False)

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            rotated = z * self.freqs  # complex mul, (SEQ, DIM//2)
            r = torch.view_as_real(rotated)  # (SEQ, DIM//2, 2)
            flat = r.reshape(z.shape[0], -1)  # (SEQ, DIM)
            return self.proj(flat)  # (SEQ, DIM) real output

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex values with random phase, shape (SEQ, DIM // 2).
        angles = torch.rand(SEQ, DIM // 2)
        return torch.polar(torch.ones_like(angles), angles).cuda()

    # try/finally guarantees the dynamo state is reset even when export,
    # compilation, refit or the final assertion raises, so a failure here
    # cannot leak cached state into subsequent tests.
    try:
        # Each model gets its own random freqs buffer and its own randomly
        # initialised projection weight, so both kinds of weight change on refit.
        model1 = ComplexRotateAndProject(make_freqs()).eval().cuda()
        model2 = ComplexRotateAndProject(make_freqs()).eval().cuda()

        z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda()
        inputs = [z]

        exp_program1 = torch.export.export(model1, tuple(inputs))
        exp_program2 = torch.export.export(model2, tuple(inputs))

        # immutable_weights=False is what makes the engine refittable.
        trt_gm = torchtrt.dynamo.compile(
            exp_program1,
            tuple(inputs),
            use_python_runtime=True,
            enabled_precisions={torch.float},
            min_block_size=1,
            immutable_weights=False,
        )

        new_trt_gm = refit_module_weights(
            compiled_module=trt_gm,
            new_weight_module=exp_program2,
            arg_inputs=inputs,
            use_weight_map_cache=True,
            verify_output=True,
        )

        expected_output = exp_program2.module()(*inputs)
        refitted_output = new_trt_gm(*inputs)

        assertions.assertTrue(
            torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2),
            "Refit with complex buffer + real param failed: output mismatch",
        )
    finally:
        # Clean up model env
        torch._dynamo.reset()
1276+
1277+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_dual_complex_buffer_refit():
    """Refit a model with two independent complex buffers.

    Intended to ensure Stage 3 value-based matching correctly distinguishes the
    real and imaginary slices of freqs_a from those of freqs_b when both are
    unpacked to separate _unpacked_complex state-dict entries with the same
    shape. The refitted output is compared against the eager output of the
    second exported program (atol=1e-2, rtol=1e-2).
    """

    SEQ, DIM = 8, 32

    class DualComplexFreqModel(nn.Module):
        def __init__(self, freqs_a: torch.Tensor, freqs_b: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs_a", freqs_a.cuda())
            self.register_buffer("freqs_b", freqs_b.cuda())

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            ra = torch.view_as_real(z * self.freqs_a)  # (SEQ, DIM//2, 2)
            rb = torch.view_as_real(z * self.freqs_b)  # (SEQ, DIM//2, 2)
            return ra + rb  # real output

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex values with random phase, shape (SEQ, DIM // 2).
        angles = torch.rand(SEQ, DIM // 2)
        return torch.polar(torch.ones_like(angles), angles).cuda()

    # try/finally guarantees the dynamo state is reset even when export,
    # compilation, refit or the final assertion raises, so a failure here
    # cannot leak cached state into subsequent tests.
    try:
        # Four independent random buffers: two per model, all of the same shape.
        model1 = DualComplexFreqModel(make_freqs(), make_freqs()).eval()
        model2 = DualComplexFreqModel(make_freqs(), make_freqs()).eval()

        z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda()
        inputs = [z]

        exp_program1 = torch.export.export(model1, tuple(inputs))
        exp_program2 = torch.export.export(model2, tuple(inputs))

        # immutable_weights=False is what makes the engine refittable.
        trt_gm = torchtrt.dynamo.compile(
            exp_program1,
            tuple(inputs),
            use_python_runtime=True,
            enabled_precisions={torch.float},
            min_block_size=1,
            immutable_weights=False,
        )

        new_trt_gm = refit_module_weights(
            compiled_module=trt_gm,
            new_weight_module=exp_program2,
            arg_inputs=inputs,
            use_weight_map_cache=True,
            verify_output=True,
        )

        expected_output = exp_program2.module()(*inputs)
        refitted_output = new_trt_gm(*inputs)

        assertions.assertTrue(
            torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2),
            "Refit with dual complex buffers failed: output mismatch",
        )
    finally:
        # Clean up model env
        torch._dynamo.reset()
0 commit comments