Skip to content

Commit ea3e98a

Browse files
committed
feat(//py/torch_tensorrt/dynamo): Allow the refit system to cache complex numerics
1 parent 3e0c104 commit ea3e98a

File tree

3 files changed

+308
-8
lines changed

3 files changed

+308
-8
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
from torch_tensorrt.dynamo.utils import (
4444
CPU_DEVICE,
4545
check_module_output,
46+
check_output_equal,
4647
get_model_device,
4748
get_torch_inputs,
4849
to_torch_device,
@@ -110,6 +111,17 @@ def construct_refit_mapping_from_weight_name_map(
110111
engine_weight_name.split(" ")[-1].lower()
111112
)
112113

114+
elif isinstance(sd_weight_name, tuple):
115+
# Buffer-slice mapping created by Stage 3 of _save_weight_mapping.
116+
# Encodes (state_dict_key, dim, index) for weights that are slices
117+
# of a source buffer (e.g. real/imag parts of an unpacked complex buffer).
118+
sd_key, dim, idx = sd_weight_name
119+
if sd_key not in state_dict:
120+
continue
121+
engine_weight_map[engine_weight_name] = (
122+
state_dict[sd_key].select(dim, idx).to(to_torch_device(settings.device))
123+
)
124+
113125
elif sd_weight_name not in state_dict:
114126
# If weights is not in sd, we can leave it unchanged
115127
continue
@@ -585,14 +597,31 @@ def refit_module_weights(
585597

586598
if verify_output and arg_inputs is not None:
587599
new_gm.to(to_torch_device(settings.device))
588-
if check_module_output(
589-
new_module=new_gm,
590-
refitted_module=compiled_module,
591-
arg_inputs=torch_inputs,
592-
kwarg_inputs=torch_kwarg_inputs,
593-
):
600+
# complex_graph_detection rewrites complex placeholders to real (view_as_real).
601+
# The compiled TRT module handles complex→real internally, but the lowered
602+
# PyTorch reference module (new_gm) expects real-unpacked inputs directly.
603+
has_complex_inputs = any(
604+
isinstance(x, torch.Tensor) and x.is_complex() for x in torch_inputs
605+
)
606+
if has_complex_inputs:
607+
lowered_inputs = [
608+
torch.view_as_real(x).contiguous()
609+
if isinstance(x, torch.Tensor) and x.is_complex()
610+
else x
611+
for x in torch_inputs
612+
]
613+
trt_outputs = compiled_module(*torch_inputs)
614+
ref_outputs = new_gm(*lowered_inputs, **torch_kwarg_inputs)
615+
outputs_match = check_output_equal(trt_outputs, ref_outputs)
616+
else:
617+
outputs_match = check_module_output(
618+
new_module=new_gm,
619+
refitted_module=compiled_module,
620+
arg_inputs=torch_inputs,
621+
kwarg_inputs=torch_kwarg_inputs,
622+
)
623+
if outputs_match:
594624
logger.info("Refitting Succeed!")
595-
new_gm.to(CPU_DEVICE)
596625
else:
597626
if weight_name_map:
598627
logger.warning(
@@ -608,7 +637,6 @@ def refit_module_weights(
608637
in_place=in_place,
609638
)
610639
logger.error("Refitting Failed! The outputs do not match.")
611-
new_gm.to(CPU_DEVICE)
612640
else:
613641
logger.info("Refitting Completed! Output verification skipped.")
614642

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -587,6 +587,41 @@ def _save_weight_mapping(self) -> None:
587587
weight_refit_map[engine_weight_name].dtype,
588588
]
589589

590+
# Stage 3: Slice matching for unmatched non-scalar CONSTANT weights.
591+
# complex_graph_detection unpacks complex buffers to real:
592+
# freqs (S,D complex64) → freqs_unpacked_complex (S,D,2 float32)
593+
# The real and imag slices (freqs_unpacked_complex[...,0] and [...,1]) are
594+
# embedded as separate TRT constants, but their shapes differ from the source
595+
# buffer, so Stage 2 value matching fails. Here we try selecting each slice
596+
# along the last dimension of every sd entry to find the match.
597+
for engine_weight_name, val in list(weight_name_map.items()):
598+
if not isinstance(val, list) or len(val) != 2:
599+
continue
600+
sd_weight_name, dtype_val = val
601+
if sd_weight_name != "" or engine_weight_name not in weight_refit_map:
602+
continue
603+
ew_tensor = weight_refit_map[engine_weight_name].to(torch_device)
604+
if ew_tensor.numel() <= 1:
605+
continue # scalars are handled via constant_mapping
606+
matched = False
607+
for sd_key, sd_tensor in sd.items():
608+
if sd_tensor.dim() < 1 or sd_tensor.shape[-1] < 2:
609+
continue
610+
last_dim = sd_tensor.dim() - 1
611+
for idx in range(sd_tensor.shape[last_dim]):
612+
sd_slice = sd_tensor.select(last_dim, idx)
613+
if TRTInterpreter.check_weight_equal(
614+
sd_slice, ew_tensor, torch_device
615+
):
616+
weight_name_map[engine_weight_name] = [
617+
(sd_key, last_dim, idx),
618+
dtype_val,
619+
]
620+
matched = True
621+
break
622+
if matched:
623+
break
624+
590625
weight_name_map["constant_mapping"] = constant_mapping
591626
self.weight_name_map = weight_name_map
592627

tests/py/dynamo/models/test_model_refit.py

Lines changed: 237 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -521,6 +521,10 @@ def test_refit_one_engine_bert_with_weightmap():
521521
torch._dynamo.reset()
522522

523523

524+
@unittest.skipIf(
525+
not importlib.util.find_spec("torchvision"),
526+
"torchvision is not installed",
527+
)
524528
@unittest.skipIf(
525529
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
526530
"TorchScript Frontend is not available",
@@ -577,6 +581,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap(tmpdir):
577581
torch._dynamo.reset()
578582

579583

584+
@unittest.skipIf(
585+
not importlib.util.find_spec("torchvision"),
586+
"torchvision is not installed",
587+
)
580588
@unittest.skipIf(
581589
not torch_trt.ENABLED_FEATURES.refit,
582590
"Refit feature is not supported in Python 3.13 or higher",
@@ -764,6 +772,10 @@ def forward(self, x):
764772
torch._dynamo.reset()
765773

766774

775+
@unittest.skipIf(
776+
not importlib.util.find_spec("torchvision"),
777+
"torchvision is not installed",
778+
)
767779
@unittest.skipIf(
768780
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
769781
"TorchScript Frontend is not available",
@@ -879,6 +891,10 @@ def test_refit_one_engine_bert_without_weightmap():
879891
torch._dynamo.reset()
880892

881893

894+
@unittest.skipIf(
895+
not importlib.util.find_spec("torchvision"),
896+
"torchvision is not installed",
897+
)
882898
@unittest.skipIf(
883899
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
884900
"TorchScript Frontend is not available",
@@ -932,6 +948,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(tmpdir):
932948
torch._dynamo.reset()
933949

934950

951+
@unittest.skipIf(
952+
not importlib.util.find_spec("torchvision"),
953+
"torchvision is not installed",
954+
)
935955
@unittest.skipIf(
936956
not torch_trt.ENABLED_FEATURES.refit,
937957
"Refit feature is not supported in Python 3.13 or higher",
@@ -1107,3 +1127,220 @@ def forward(self, x):
11071127
# Clean up model env
11081128

11091129
torch._dynamo.reset()
1130+
1131+
1132+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_complex_buffer_refit():
    """Refit a compiled module whose only weight is a complex-valued buffer.

    Two copies of the same model are built with different random unit-modulus
    complex ``freqs`` buffers (a RoPE-style table). The first copy is compiled
    with refittable weights, then refitted from the second copy using the
    weight-map cache with output verification enabled. The refitted module
    must reproduce the second model's output on a complex input within
    ``atol=rtol=1e-2``.
    """
    seq_len, width = 8, 32

    class ComplexFreqModel(nn.Module):
        def __init__(self, freqs: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs", freqs.cuda())

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            # Complex multiply, then expose the result as a real tensor with a
            # trailing (real, imag) axis so the module output is real-valued.
            rotated = z * self.freqs
            return torch.view_as_real(rotated)

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex numbers with random phase.
        angles = torch.rand(seq_len, width // 2)
        magnitudes = torch.ones_like(angles)
        return torch.polar(magnitudes, angles).cuda()

    # Draw both frequency tables before building the models (same RNG order
    # as the input drawn afterwards).
    original_freqs = make_freqs()
    replacement_freqs = make_freqs()

    original_model = ComplexFreqModel(original_freqs).eval()
    replacement_model = ComplexFreqModel(replacement_freqs).eval()

    z = torch.randn(seq_len, width // 2, dtype=torch.complex64).cuda()
    inputs = [z]
    export_args = tuple(inputs)

    original_program = torch.export.export(original_model, export_args)
    replacement_program = torch.export.export(replacement_model, export_args)

    compile_options = dict(
        use_python_runtime=True,
        enabled_precisions={torch.float},
        min_block_size=1,
        immutable_weights=False,  # engine must stay refittable
    )
    trt_gm = torchtrt.dynamo.compile(
        original_program,
        export_args,
        **compile_options,
    )

    refitted_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=replacement_program,
        arg_inputs=inputs,
        use_weight_map_cache=True,
        verify_output=True,
    )

    reference = replacement_program.module()(*inputs)
    actual = refitted_gm(*inputs)
    outputs_close = torch.allclose(reference, actual, atol=1e-2, rtol=1e-2)

    assertions.assertTrue(
        outputs_close,
        "Refit with complex buffer failed: output mismatch after refit",
    )

    # Clean up model env
    torch._dynamo.reset()
1203+
1204+
1205+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_complex_buffer_with_real_param_refit():
    """Refit a model combining a complex buffer with a real ``nn.Linear`` weight.

    The model rotates a complex input by a complex ``freqs`` buffer, flattens
    the real view of the result, and projects it with a bias-free linear
    layer. Both the buffer and the projection weight differ between the
    compiled model and the model supplying new weights, so the refit has to
    update the complex-derived constants and the ordinary real parameter
    together. The refitted output must match the new model within
    ``atol=rtol=1e-2``.
    """
    seq_len, width = 8, 32

    class ComplexRotateAndProject(nn.Module):
        def __init__(self, freqs: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs", freqs.cuda())
            self.proj = nn.Linear(width, width, bias=False)

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            # complex mul -> (seq_len, width // 2)
            rotated = z * self.freqs
            # real view -> (seq_len, width // 2, 2)
            as_real = torch.view_as_real(rotated)
            # merge the last two axes -> (seq_len, width)
            flattened = as_real.reshape(z.shape[0], -1)
            # real output, (seq_len, width)
            return self.proj(flattened)

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex numbers with random phase.
        angles = torch.rand(seq_len, width // 2)
        magnitudes = torch.ones_like(angles)
        return torch.polar(magnitudes, angles).cuda()

    # Each model draws its frequencies immediately before construction so the
    # RNG sequence (freqs, then Linear init) is the same for both.
    original_model = ComplexRotateAndProject(make_freqs()).eval().cuda()
    replacement_model = ComplexRotateAndProject(make_freqs()).eval().cuda()

    z = torch.randn(seq_len, width // 2, dtype=torch.complex64).cuda()
    inputs = [z]
    export_args = tuple(inputs)

    original_program = torch.export.export(original_model, export_args)
    replacement_program = torch.export.export(replacement_model, export_args)

    compile_options = dict(
        use_python_runtime=True,
        enabled_precisions={torch.float},
        min_block_size=1,
        immutable_weights=False,  # engine must stay refittable
    )
    trt_gm = torchtrt.dynamo.compile(
        original_program,
        export_args,
        **compile_options,
    )

    refitted_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=replacement_program,
        arg_inputs=inputs,
        use_weight_map_cache=True,
        verify_output=True,
    )

    reference = replacement_program.module()(*inputs)
    actual = refitted_gm(*inputs)
    outputs_close = torch.allclose(reference, actual, atol=1e-2, rtol=1e-2)

    assertions.assertTrue(
        outputs_close,
        "Refit with complex buffer + real param failed: output mismatch",
    )

    # Clean up model env
    torch._dynamo.reset()
1276+
1277+
1278+
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
    "TorchScript Frontend is not available",
)
@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Refit feature is not supported in Python 3.13 or higher",
)
@pytest.mark.unit
def test_dual_complex_buffer_refit():
    """Refit a model that carries two independent complex buffers.

    ``freqs_a`` and ``freqs_b`` have the same shape but different random
    values, so the refit must attribute each engine constant to the correct
    source buffer rather than confusing one buffer's real/imaginary parts with
    the other's. The refitted output must match the model supplying the new
    weights within ``atol=rtol=1e-2``.
    """
    seq_len, width = 8, 32

    class DualComplexFreqModel(nn.Module):
        def __init__(self, freqs_a: torch.Tensor, freqs_b: torch.Tensor):
            super().__init__()
            self.register_buffer("freqs_a", freqs_a.cuda())
            self.register_buffer("freqs_b", freqs_b.cuda())

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            # Each branch: complex mul, then real view (seq_len, width // 2, 2).
            branch_a = torch.view_as_real(z * self.freqs_a)
            branch_b = torch.view_as_real(z * self.freqs_b)
            # Sum of the two real views is the (real) model output.
            return branch_a + branch_b

    def make_freqs() -> torch.Tensor:
        # Unit-magnitude complex numbers with random phase.
        angles = torch.rand(seq_len, width // 2)
        magnitudes = torch.ones_like(angles)
        return torch.polar(magnitudes, angles).cuda()

    # Arguments are evaluated left to right: freqs_a is drawn before freqs_b.
    original_model = DualComplexFreqModel(make_freqs(), make_freqs()).eval()
    replacement_model = DualComplexFreqModel(make_freqs(), make_freqs()).eval()

    z = torch.randn(seq_len, width // 2, dtype=torch.complex64).cuda()
    inputs = [z]
    export_args = tuple(inputs)

    original_program = torch.export.export(original_model, export_args)
    replacement_program = torch.export.export(replacement_model, export_args)

    compile_options = dict(
        use_python_runtime=True,
        enabled_precisions={torch.float},
        min_block_size=1,
        immutable_weights=False,  # engine must stay refittable
    )
    trt_gm = torchtrt.dynamo.compile(
        original_program,
        export_args,
        **compile_options,
    )

    refitted_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=replacement_program,
        arg_inputs=inputs,
        use_weight_map_cache=True,
        verify_output=True,
    )

    reference = replacement_program.module()(*inputs)
    actual = refitted_gm(*inputs)
    outputs_close = torch.allclose(reference, actual, atol=1e-2, rtol=1e-2)

    assertions.assertTrue(
        outputs_close,
        "Refit with dual complex buffers failed: output mismatch",
    )

    # Clean up model env
    torch._dynamo.reset()

0 commit comments

Comments
 (0)