
Commit a4ff6bb

Fixed fp16 quantization error
1 parent 18b6455 commit a4ff6bb

File tree

2 files changed (+7, -5)


examples/apps/flux-demo.py

Lines changed: 5 additions & 2 deletions
@@ -15,7 +15,7 @@
 parser.add_argument(
     "--dtype",
     choices=["fp8", "int8", "fp16"],
-    default="int8",
+    default="fp16",
     help="Select the data type to use (fp8 or int8 or fp16)",
 )
 args = parser.parse_args()
@@ -30,7 +30,7 @@
     ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
 elif args.dtype == "fp16":
     enabled_precisions = {torch.float16}
-print(f"\nUsing {args.dtype} quantization")
+print(f"\nUsing {args.dtype}")
 
 
 DEVICE = "cuda:0"
@@ -152,6 +152,9 @@ def load_lora(path):
     print("Refitting Finished!")
 
 
+load_lora("/home/TensorRT/examples/apps/NGRVNG.safetensors")
+
+
 # Create Gradio interface
 with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
     gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")
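
For context on the fp16 branch: it skips the ModelOpt PTQ config entirely and only sets enabled_precisions, which the demo later passes to Torch-TensorRT compilation. A minimal sketch of that flow, with a placeholder module standing in for the Flux pipeline (not the demo's actual model or settings):

import torch
import torch_tensorrt

# Placeholder module; the demo compiles the Flux transformer instead.
model = torch.nn.Linear(64, 64).half().eval().cuda()
inputs = [torch.randn(1, 64, dtype=torch.float16, device="cuda")]

# fp16 selects a compilation precision rather than a quantization recipe,
# so no calibration step is needed before compiling.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float16},  # the set built by the fp16 branch
)

With the new default, running the script without --dtype now takes this fp16 path; --dtype fp8 and --dtype int8 still go through the ModelOpt quantization branches.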

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 2 additions & 3 deletions
@@ -66,7 +66,7 @@ def quantize(
         if not isinstance(amax, trt.ITensor):
             amax = to_torch(amax, None)
             scale = torch.divide(amax, max_bound)
-            scale = get_trt_tensor(ctx, scale, name + "_scale")
+            scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32)
         else:
             scale = impl.elementwise.div(
                 ctx,
@@ -76,7 +76,7 @@
                 amax,
                 max_bound,
             )
-            scale = get_trt_tensor(ctx, scale, name + "_scale")
+            scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32)
 
     # Add Q node
     if num_bits == 8 and exponent_bits == 0:
@@ -96,7 +96,6 @@
             q_output, scale, output_type=input_tensor.dtype
         )
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        dequantize_layer.precision = dtype
 
         dq_output = dequantize_layer.get_output(0)
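
My reading of the converter fix (the commit message doesn't spell this out): with an fp16 model, amax arrives as a float16 tensor, so the scale derived from it is float16 as well, while the TensorRT Q/DQ pair expects an fp32 scale constant; forcing dtype=torch.float32 in get_trt_tensor avoids that mismatch, and dropping the dequantize_layer.precision override lets output_type=input_tensor.dtype control the output precision on its own. A small sketch of the dtype propagation, using a made-up amax value:

import torch

amax = torch.tensor(4.2, dtype=torch.float16)  # hypothetical fp16 calibration value
max_bound = 127  # int8 bound, mirroring the converter

# Dividing an fp16 tensor by a Python scalar stays fp16, so without the
# explicit dtype the Q/DQ scale would be created as a float16 constant.
scale = torch.divide(amax, max_bound)
print(scale.dtype)  # torch.float16

# Effect of dtype=torch.float32 in get_trt_tensor: an fp32 scale constant.
scale = scale.to(torch.float32)
print(scale.dtype)  # torch.float32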
