diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index d7c94d3b80..9aae901f87 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -101,7 +101,7 @@ def construct_refit_mapping_from_weight_name_map( params[w.split(".")[-1]] = state_dict[w].cuda() # Batch norm constant folding - scale, shift = batch_norm_constant_folding(**params, eps=1e-7) + scale, shift = batch_norm_constant_folding(**params, eps=1e-5) # Set scale to scale or shift to shift engine_weight_map[engine_weight_name] = eval( engine_weight_name.split(" ")[-1].lower() diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 44eb455ec2..d1ae28fb13 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -89,6 +89,70 @@ def test_mapping(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +@pytest.mark.unit +def test_conv_refit_with_weightmap(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, x): + return self.conv(x) + + model = net().eval().to("cuda") + model2 = net().eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + min_block_size = 1 + use_python_runtime = True + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + min_block_size=min_block_size, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + # Check the output + model2.to("cuda") + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index a0af6420ed..b2caa2551b 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -317,9 +317,7 @@ def test_resnet18_modify_attribute(): mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) - mutable_module.conv1.weight = nn.Parameter( - torch.rand_like(mutable_module.conv1.weight) - ) + mutable_module.fc.weight = nn.Parameter(torch.rand_like(mutable_module.fc.weight)) assertions.assertEqual( mutable_module.refit_state.get_state(), RefitFlag.UNKNOWN,