
Commit 8deb1f9

chore: add modelopt tests file
1 parent: e923627

1 file changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# type: ignore
import importlib
import platform
import unittest
from importlib import metadata

import pytest
import torch
import torch_tensorrt as torchtrt

from packaging.version import Version

assertions = unittest.TestCase()


@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
)
@unittest.skipIf(
    not importlib.util.find_spec("modelopt"),
    "ModelOpt is required to run this test",
)
@pytest.mark.unit
def test_base_fp8():
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            x = torch.nn.ReLU()(x)
            x = self.linear2(x)
            return x

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.randn(1, 10).cuda()
    model = SimpleNetwork().eval().cuda()

    # Quantize with ModelOpt's default FP8 config, driving one calibration pass.
    quant_cfg = mtq.FP8_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has FP8 qdq nodes at this point
    output_pyt = model(input_tensor)

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions={torch.float8_e4m3fn},
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
            )
            outputs_trt = trt_model(input_tensor)
            # TensorRT output should match PyTorch within quantization tolerance.
            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)


@unittest.skipIf(
    platform.system() != "Linux"
    or not importlib.util.find_spec("modelopt")
    or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"),
    "modelopt 0.27.0 or later on Linux is required to run this test; Int8 quantization is supported in modelopt since 0.17.0",
)
@pytest.mark.unit
def test_base_int8():
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            x = torch.nn.ReLU()(x)
            x = self.linear2(x)
            return x

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.randn(1, 10).cuda()
    model = SimpleNetwork().eval().cuda()

    # Quantize with ModelOpt's default INT8 config, driving one calibration pass.
    quant_cfg = mtq.INT8_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has INT8 qdq nodes at this point
    output_pyt = model(input_tensor)

    with torchtrt.logging.debug(), torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)
            # truncate_double lowers any float64 values to float32 for TensorRT.
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions={torch.int8},
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
                truncate_double=True,
                debug=True,
            )
            outputs_trt = trt_model(input_tensor)
            # TensorRT output should match PyTorch within quantization tolerance.
            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
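For reference, a minimal sketch of how these tests might be invoked locally. The file name test_modelopt.py is an assumption (the commit page does not show the file's path), and running it requires a CUDA GPU plus the nvidia-modelopt package:

# Hypothetical runner for the tests above. The file name "test_modelopt.py"
# is an assumption; the commit page does not show the actual path.
import sys

import pytest

# Select only the @pytest.mark.unit tests from this file and report verbosely.
sys.exit(pytest.main(["-v", "-m", "unit", "test_modelopt.py"]))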
