
Commit bbd6d97

Merge branch 'lluo/fp4_issue_debugging' into lluo/flux_fp4
2 parents 9f803ee + f989864

2 files changed (+108 -7 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 1 deletion
@@ -629,7 +629,10 @@ def aten_ops_quantize_op(
     )
 else:
 
-    @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        supports_dynamic_shapes=True,
+    )
     def aten_ops_dynamic_block_quantize_op(
         ctx: ConversionContext,
         target: Target,
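
The functional change in this file is that the converter for torch.ops.tensorrt.dynamic_block_quantize_op is now registered with supports_dynamic_shapes=True, so the dynamo partitioner keeps the op on TensorRT when the exported graph carries dynamic dimensions (exercised by the new test in the second file) instead of skipping the converter. A minimal sketch of the registration pattern, assuming the decorator import used by the torch_tensorrt source layout (the import path and the stub body are assumptions; only the decorator arguments come from this diff):

# Sketch only: import path assumed, converter body left as a stub.
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter


@dynamo_tensorrt_converter(
    torch.ops.tensorrt.dynamic_block_quantize_op.default,
    supports_dynamic_shapes=True,  # opt the converter in for graphs exported with dynamic dims
)
def aten_ops_dynamic_block_quantize_op(ctx, target, args, kwargs, name):
    ...  # delegates to the existing dynamic block quantize implementation, unchanged by this commit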

tests/py/dynamo/models/test_models_export.py

Lines changed: 104 additions & 6 deletions
@@ -3,6 +3,7 @@
 import platform
 import unittest
 from importlib import metadata
+
 import pytest
 import timm
 import torch
@@ -15,6 +16,7 @@
 assertions = unittest.TestCase()
 import os
 
+
 @pytest.mark.unit
 def test_resnet18(ir):
     model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -198,6 +200,96 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()
 
 
+# @unittest.skipIf(
+#     torch.cuda.get_device_capability() < (10, 0),
+#     "FP4 quantization requires compute capability 10.0 or later",
+# )
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_dynamic_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.float16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(dummy_inputs)
+
+    BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=16, max=128)
+    batch_size = 64
+    dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+    # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
+    # model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
+
+    print(f"lan added model weight: {model.linear1.weight=}")
+    print(f"lan added model bias: {model.linear1.bias=}")
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(
+                model, (dummy_inputs,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
+            )
+
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[dummy_inputs],
+                enabled_precisions={
+                    torch.float4_e2m1fn_x2,
+                    torch.float8_e4m3fn,
+                    torch.float32,
+                    torch.float16,
+                },
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=dtype == torch.float16,
+            )
+            batch_size = 128
+            input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda()
+            expected_output = model(input_tensor)
+            outputs_trt = trt_model(input_tensor)
+            if os.getenv("DISABLE_GEMM", "false").lower() == "true":
+                print("lan added disable_gemm is set, compring result with weights")
+                expected_output = model.linear1.weight
+            else:
+                print("lan added disable_gemm is not set, compring result with pytorch")
+
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )
+
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
+            print(f"lan added abs_diff: {abs_diff=}")
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)
+
+
 # @unittest.skipIf(
 #     torch.cuda.get_device_capability() < (10, 0),
 #     "FP4 quantization requires compute capability 10.0 or later",
@@ -210,6 +302,7 @@ def test_resnet18_half(ir):
 def test_base_fp4(ir):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.quantization.utils import export_torch_mode
+
     dtype = torch.float16
 
     class SimpleNetwork(torch.nn.Module):
@@ -229,17 +322,16 @@ def calibrate_loop(model):
 
     input_tensor = torch.ones(128, 64, dtype=dtype).cuda()
 
-
     model = SimpleNetwork().eval().cuda()
     model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
     model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
     print(f"lan added amax: {input_tensor.abs().amax()=}")
     print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
     expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")
+    print(f"lan added model input: {input_tensor=}")
     print(f"lan added model weight: {model.linear1.weight=}")
     print(f"lan added model bias: {model.linear1.bias=}")
-
+
     quant_cfg = mtq.NVFP4_DEFAULT_CFG
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has qdq nodes at this point
@@ -271,11 +363,17 @@ def calibrate_loop(model):
             else:
                 print("lan added disable_gemm is not set, compring result with pytorch")
 
-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}")
-            print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )
 
             abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
             print(f"lan added abs_diff: {abs_diff=}")
             assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)
 
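
The substantive addition in the test file is test_base_fp4_dynamic_shapes, which exports the quantized model with a torch.export.Dim bound to the batch dimension, compiles from a batch-64 sample input, and then runs the TensorRT engine at batch 128. For reference, a stripped-down sketch of just that dynamic-batch export pattern, without ModelOpt quantization or TensorRT in the loop (module, sizes, and names here are illustrative, not taken from the diff):

import torch


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.linear(x)


# Declare dim 0 of the (single) input as dynamic, bounded to [16, 128].
BATCH = torch.export.Dim("BATCH", min=16, max=128)
model = TinyLinear().eval()
example_input = torch.ones(64, 64)

exported = torch.export.export(
    model, (example_input,), strict=False, dynamic_shapes=({0: BATCH},)
)

# The exported program accepts any batch size within the declared bounds,
# which is what lets the test compile at batch 64 and then run at batch 128.
out = exported.module()(torch.ones(128, 64))
print(out.shape)  # torch.Size([128, 32])

In the test itself, the same kind of exported program is handed to torchtrt.dynamo.compile with FP4/FP8 among the enabled precisions, and the engine output at batch 128 is compared against eager PyTorch under the loose rtol/atol of 0.8 used throughout these FP4 tests.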