import platform
import unittest
from importlib import metadata
+
import pytest
import timm
import torch
assertions = unittest.TestCase()
import os

+
@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -198,6 +200,96 @@ def test_resnet18_half(ir):
    torch._dynamo.reset()


+# @unittest.skipIf(
+#     torch.cuda.get_device_capability() < (10, 0),
+#     "FP4 quantization requires compute capability 10.0 or later",
+# )
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_dynamic_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.float16
+
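+    # Minimal FP16 single-Linear module, small enough to quantize and export quickly in the test.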
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(dummy_inputs)
+
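+    # BATCH_SIZE is a torch.export dynamic dimension, so the exported graph accepts batch sizes from 16 to 128.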
+    BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=16, max=128)
+    batch_size = 64
+    dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+    # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
+    # model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
+
+    print(f"lan added model weight: {model.linear1.weight=}")
+    print(f"lan added model bias: {model.linear1.bias=}")
+
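+    # Quantize the model to NVFP4 with ModelOpt's default FP4 recipe; calibrate_loop supplies the calibration forward pass.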
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(
+                model, (dummy_inputs,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
+            )
+
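+            # Compile the exported program with Torch-TensorRT, allowing FP4/FP8/FP16/FP32 kernels;
+            # explicit typing is enabled when the model runs in FP16.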
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[dummy_inputs],
+                enabled_precisions={
+                    torch.float4_e2m1fn_x2,
+                    torch.float8_e4m3fn,
+                    torch.float32,
+                    torch.float16,
+                },
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=dtype == torch.float16,
+            )
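+            # Run at a different batch size (128 vs. the 64 used at export time) to exercise the dynamic dimension.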
+            batch_size = 128
+            input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda()
+            expected_output = model(input_tensor)
+            outputs_trt = trt_model(input_tensor)
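+            # DISABLE_GEMM appears to be a debugging switch set elsewhere; when true, the TRT output is
+            # compared against the raw weights instead of the PyTorch result.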
+            if os.getenv("DISABLE_GEMM", "false").lower() == "true":
+                print("lan added disable_gemm is set, comparing result with weights")
+                expected_output = model.linear1.weight
+            else:
+                print("lan added disable_gemm is not set, comparing result with pytorch")
+
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )
+
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(
+                f"lan added max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
+            print(f"lan added abs_diff: {abs_diff=}")
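+            # FP4 quantization is very lossy, hence the loose rtol/atol of 0.8.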
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)
+
+
# @unittest.skipIf(
#     torch.cuda.get_device_capability() < (10, 0),
#     "FP4 quantization requires compute capability 10.0 or later",
@@ -210,6 +302,7 @@ def test_resnet18_half(ir):
def test_base_fp4(ir):
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode
+
    dtype = torch.float16

    class SimpleNetwork(torch.nn.Module):
@@ -229,17 +322,16 @@ def calibrate_loop(model):

    input_tensor = torch.ones(128, 64, dtype=dtype).cuda()

-
    model = SimpleNetwork().eval().cuda()
    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
    model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
    print(f"lan added amax: {input_tensor.abs().amax()=}")
    print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
    expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-
+
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
@@ -271,11 +363,17 @@ def calibrate_loop(model):
            else:
                print("lan added disable_gemm is not set, comparing result with pytorch")

-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}")
-            print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )

            abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
            print(f"lan added abs_diff: {abs_diff=}")
            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)
