Skip to content

Commit 057f35a

Browse files
committed
test
1 parent fcf0c12 commit 057f35a

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

examples/dynamo/vgg16_ptq.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def calibrate_loop(model):
200200
quant_cfg = mtq.INT8_DEFAULT_CFG
201201
elif args.quantize_type == "fp8":
202202
quant_cfg = mtq.FP8_DEFAULT_CFG
203-
elif args.quantize_type == "fp4":
204-
quant_cfg = mtq.NVFP4_DEFAULT_CFG
205203
# PTQ with in-place replacement to quantized modules
206204
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
207205
# model has FP8 qdq nodes at this point
@@ -241,8 +239,6 @@ def calibrate_loop(model):
241239
enabled_precisions = {torch.int8}
242240
elif args.quantize_type == "fp8":
243241
enabled_precisions = {torch.float8_e4m3fn}
244-
elif args.quantize_type == "fp4":
245-
enabled_precisions = {torch.float4_e2m1fn_x2}
246242
trt_model = torchtrt.dynamo.compile(
247243
exp_program,
248244
inputs=[input_tensor],

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def dynamic_block_quantize(
8888
Adds quantize and dequantize ops (QDQ) which quantize to FP4 based
8989
on the output_type set and dequantizes them back.
9090
"""
91-
91+
print(
92+
f"dynamic_block_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}"
93+
)
9294
with unset_fake_temporarily():
9395
if not isinstance(input_tensor, TRTTensor):
9496
input_tensor = get_trt_tensor(
@@ -114,18 +116,28 @@ def dynamic_block_quantize(
114116
# Add Q node
115117
dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
116118
input_tensor,
117-
-1,
118-
16,
119-
trt.DataType.FP4,
120-
trt.DataType.FP8,
119+
axis=1,
120+
block_size=16,
121+
output_type=trt.DataType.FP4,
122+
scale_type=trt.DataType.FP8,
121123
)
122-
123124
set_layer_name(
124125
dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
125126
)
126127
q_output = dynamic_quantize_layer.get_output(0)
127-
# Add DQ node
128-
dequantize_layer = ctx.net.add_dequantize(q_output, scale)
128+
q_scale = dynamic_quantize_layer.get_output(1)
129+
130+
# Add double DQ node
131+
scale_dequantize_layer = ctx.net.add_dequantize(q_scale, scale)
132+
scale_dequantize_layer.axis = 0
133+
set_layer_name(
134+
scale_dequantize_layer, target, name + "_scale_dequantize", source_ir
135+
)
136+
scale_dequantize_layer.precision = trt.DataType.FP8
137+
scale_dq_output = scale_dequantize_layer.get_output(0)
138+
139+
dequantize_layer = ctx.net.add_dequantize(q_output, scale_dq_output)
140+
dequantize_layer.axis = 1
129141
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
130142
dequantize_layer.precision = trt.DataType.FP4
131143
dq_output = dequantize_layer.get_output(0)

tests/py/dynamo/models/test_models_export.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ def test_resnet18_half(ir):
199199
torch._dynamo.reset()
200200

201201

202-
@unittest.skipIf(
203-
torch.cuda.get_device_capability() < (10, 0),
204-
"FP4 quantization requires compute capability 10.0 or later",
205-
)
202+
# @unittest.skipIf(
203+
# torch.cuda.get_device_capability() < (10, 0),
204+
# "FP4 quantization requires compute capability 10.0 or later",
205+
# )
206206
@unittest.skipIf(
207207
not importlib.util.find_spec("modelopt"),
208208
"ModelOpt is required to run this test",
@@ -215,20 +215,17 @@ def test_base_fp4(ir):
215215
class SimpleNetwork(torch.nn.Module):
216216
def __init__(self):
217217
super(SimpleNetwork, self).__init__()
218-
self.linear1 = torch.nn.Linear(in_features=32, out_features=16)
219-
self.linear2 = torch.nn.Linear(in_features=16, out_features=1)
218+
self.linear1 = torch.nn.Linear(in_features=16, out_features=5)
220219

221220
def forward(self, x):
222221
x = self.linear1(x)
223-
x = torch.nn.ReLU()(x)
224-
x = self.linear2(x)
225222
return x
226223

227224
def calibrate_loop(model):
228225
"""Simple calibration function for testing."""
229226
model(input_tensor)
230227

231-
input_tensor = torch.randn(1, 32).cuda()
228+
input_tensor = torch.randn(1, 16).cuda()
232229
model = SimpleNetwork().eval().cuda()
233230

234231
quant_cfg = mtq.NVFP4_DEFAULT_CFG
@@ -283,7 +280,6 @@ def calibrate_loop(model):
283280

284281
input_tensor = torch.randn(1, 10).cuda()
285282
model = SimpleNetwork().eval().cuda()
286-
287283
quant_cfg = mtq.FP8_DEFAULT_CFG
288284
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
289285
# model has FP8 qdq nodes at this point

0 commit comments

Comments
 (0)