
Commit ba76f6d: "trying to add quantization to Flux"
Committed Apr 18, 2025 (1 parent: 9e390da)

File tree: 6 files changed (+381, -9 lines)
 
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@ (new file)
# %%
# Import the following libraries
# -----------------------------
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export
from transformers import AutoModelForCausalLM

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float32,
)
# Use a tiny dummy transformer (one double block, one single block) so the
# quantization flow can be debugged without the full model's memory footprint.
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float32)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    # Matches layers that should stay in high precision (embeddings, norms, I/O projections)
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None


def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run calibration steps on the pipeline using the given prompt.
    """
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Swap the module under calibration in as the pipeline's backbone, then run it
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


ptq_config = mtq.FP8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)


# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. The transformer
# backbone is exported with dynamic shapes and ``batch_size=2`` due to
# `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These min/max values for the img_id input are recommended by torch dynamo
# during export of the model. To see the recommendation, try exporting with
# min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float32).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float32
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float32).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float32).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float32).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# This creates an exported program which is then compiled with Torch-TensorRT
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )

with torch_tensorrt.logging.debug():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        enabled_precisions={torch.float8_e4m3fn},
        truncate_double=True,
        min_block_size=1,
        debug=False,
        use_python_runtime=True,
        immutable_weights=True,
        offload_module_to_cpu=True,
    )


del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the Flux pipeline using the compiled TRT backbone

for _ in range(2):
    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# For this dummy model, the fp16 engine size is around 1GB; the fp32 engine size is around 2GB
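Note that modelopt.torch.opt is imported as mto but never used in this script; presumably the calibrated state would eventually be checkpointed so calibration does not have to be rerun. A minimal sketch of that step with ModelOpt's save/restore utilities (a hypothetical follow-up, not part of this commit; the path is a placeholder):

# Hypothetical follow-up, not in this commit: persist the quantized backbone
# state after calibration (path is a placeholder).
mto.save(backbone, "./flux_fp8_modelopt_state.pth")
# Later: rebuild the backbone and restore the quantized state onto it.
# backbone = mto.restore(backbone, "./flux_fp8_modelopt_state.pth")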

examples/apps/flux-quantization.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@ (new file)
# %%
# Import the following libraries
# -----------------------------
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export
from transformers import AutoModelForCausalLM

# Load the ModelOpt-modified model architecture and weights using Hugging Face APIs

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
# Use a tiny dummy transformer (one double block, one single block) so the
# quantization flow can be debugged without the full model's memory footprint.
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    # Matches layers that should stay in high precision (embeddings, norms, I/O projections)
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None


def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run calibration steps on the pipeline using the given prompt.
    """
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Swap the module under calibration in as the pipeline's backbone, then run it
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


ptq_config = mtq.FP8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)

batch_size = 1
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These min/max values for the img_id input are recommended by torch dynamo
# during export of the model. To see the recommendation, try exporting with
# min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# This creates an exported program which is then compiled with Torch-TensorRT
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        # dynamic_shapes=dynamic_shapes,  # defined above but disabled for this export
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )

with torch_tensorrt.logging.debug():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        enabled_precisions={torch.float8_e4m3fn, torch.float16},
        truncate_double=True,
        min_block_size=1,
        debug=True,
        use_python_runtime=True,
        immutable_weights=True,
        offload_module_to_cpu=True,
    )


del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the Flux pipeline using the compiled TRT backbone

for _ in range(2):
    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# For this dummy model, the fp16 engine size is around 1GB; the fp32 engine size is around 2GB
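A possible follow-up not in this commit: serialize the compiled module so the TensorRT engine does not have to be rebuilt on every run. A sketch, assuming the torch_tensorrt.save API and its kwarg_inputs parameter are available on this branch; the output path is a placeholder:

# Hypothetical sketch; torch_tensorrt.save and its kwarg_inputs parameter are
# assumed here, and the path is a placeholder.
torch_tensorrt.save(
    trt_gm,
    "./flux_trt.ep",
    output_format="exported_program",
    kwarg_inputs=dummy_inputs,
)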

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 1 deletion
@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:
 
-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
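For context: converters registered without supports_dynamic_shapes=True are treated as not validated for dynamic shapes, so ops hitting them can be partitioned out of the TensorRT engine when the graph carries dynamic dimensions. This change keeps quantize_op convertible for the dynamically shaped Flux export above. A sketch of the same registration pattern on another op (illustrative only; it assumes it sits in aten_ops_converters.py, where ConversionContext, Target, impl, SourceIR, and the typing helpers are already in scope):

@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)
def aten_ops_relu(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Delegate to the shared activation implementation, as the real
    # converters in this file do.
    return impl.activation.relu(ctx, target, SourceIR.ATEN, name, args[0])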

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 41 additions & 7 deletions
@@ -6,12 +6,31 @@
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
 
+def get_ir(target: Target) -> SourceIR:
+    target_module = getattr(target, "__module__", "None")
+    if any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.aten", "torch._ops.aten")
+    ):
+        return SourceIR.ATEN
+    elif any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.prims", "torch._ops.prims")
+    ):
+        return SourceIR.PRIM
+    elif target_module.startswith("torch.nn"):
+        return SourceIR.NN
+
+    return SourceIR.UNKNOWN
+
+
 def quantize(
     ctx: ConversionContext,
     target: Target,
@@ -44,20 +63,35 @@ def quantize(
     elif num_bits == 8 and exponent_bits == 4:
         max_bound = 448
 
-    amax = to_torch(amax, None)
-    scale = torch.divide(amax, max_bound)
-    scale = get_trt_tensor(ctx, scale, name + "_scale")
+    if not isinstance(amax, trt.ITensor):
+        amax = to_torch(amax, None)
+        scale = torch.divide(amax, max_bound)
+        scale = get_trt_tensor(ctx, scale, name + "_scale")
+    else:
+        scale = impl.elementwise.div(
+            ctx,
+            target,
+            get_ir(target),
+            name,
+            amax,
+            max_bound,
+        )
+        scale = get_trt_tensor(ctx, scale, name + "_scale")
+
     # Add Q node
-    quantize_layer = ctx.net.add_quantize(input_tensor, scale)
     if num_bits == 8 and exponent_bits == 0:
-        quantize_layer.set_output_type(0, trt.DataType.INT8)
+        dtype = trt.DataType.INT8
     elif num_bits == 8 and exponent_bits == 4:
-        quantize_layer.set_output_type(0, trt.DataType.FP8)
+        dtype = trt.DataType.FP8
+
+    quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
 
     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
     q_output = quantize_layer.get_output(0)
     # Add DQ node
-    dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+    dequantize_layer = ctx.net.add_dequantize(
+        q_output, scale, output_type=input_tensor.dtype
+    )
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
     if num_bits == 8 and exponent_bits == 0:
         dequantize_layer.precision = trt.DataType.INT8
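The net effect of this converter is a Q node followed by a DQ node sharing one scale, where scale = amax / max_bound and max_bound is 448 for FP8 (the largest finite E4M3 value); with this patch, amax may also arrive as a TensorRT ITensor, in which case the division is built into the graph via impl.elementwise.div. A standalone PyTorch sketch of the FP8 round trip (illustrative arithmetic only; inside the engine TensorRT executes it via the quantize/dequantize layers built above, and the amax value here is an assumption):

import torch

amax = torch.tensor(3.5)        # calibrated per-tensor absolute maximum (assumed value)
scale = amax / 448.0            # the same division the converter performs for FP8

x = torch.randn(16)
q = (x / scale).to(torch.float8_e4m3fn)   # Q node: scale, then cast to FP8
dq = q.to(torch.float32) * scale          # DQ node: cast back to the input dtype
print(torch.max(torch.abs(x - dq)))       # small quantization error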

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 1 deletion
@@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+    if isinstance(tensor, (torch.Tensor, FakeTensor)):
+        return tensor.dtype
+    elif isinstance(tensor, (int, float, bool)):
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
@@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
                 output_dtypes.append(dtype.float32)
             else:
                 output_dtypes.append(dtype._from(output_meta.dtype))
+        elif isinstance(output_meta, torch.SymInt):
+            output_dtypes.append(dtype.int64)
         elif "tensor_meta" in output.meta:
             output_meta = output.meta["tensor_meta"]
             output_dtypes.append(dtype._from(output_meta.dtype))
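With this change, real tensors and FakeTensors report their dtype directly instead of being round-tripped through torch.tensor(), while Python scalars keep the old path; SymInts produced by dynamic shapes now also yield int64 in get_output_dtypes. Expected behavior, a sketch inferred from the branch logic above (illustrative, not part of the commit):

import torch

unwrap_tensor_dtype(torch.ones(2, dtype=torch.float16))  # -> torch.float16
unwrap_tensor_dtype(3)     # -> torch.int64 (via torch.tensor(3).dtype)
unwrap_tensor_dtype(2.0)   # -> torch.float32
unwrap_tensor_dtype(True)  # -> torch.bool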

tools/perf/Flux/create_env.sh

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@ pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" di
 
 pip install notebook
 pip install gradio safetensors peft pyinstrument
+pip install nvidia-modelopt onnx torchprofile pulp onnxruntime
