
Commit 9f803ee

add flux
1 parent 3f09d97 commit 9f803ee

File tree

8 files changed: +437 -24 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 35 additions & 1 deletion
@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:
 
-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
@@ -650,6 +652,38 @@ def aten_ops_dynamic_block_quantize_op(
         )
 
 
+def attention_validator(
+    node: Node, settings: Optional[CompilationSettings] = None
+) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
+@dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
+    supports_dynamic_shapes=True,
+)
+def tensorrt_scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.attention.scaled_dot_product_attention(
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args_bounds_check(args, 5, False),
+        kwargs.get("scale", None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
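
With this converter registered, torch.nn.functional.scaled_dot_product_attention calls can stay inside the TensorRT engine as long as no attn_mask is passed; attention_validator rejects masked calls, so those nodes fall back to PyTorch execution. A minimal usage sketch, assuming a CUDA-enabled torch_tensorrt build (module and shapes here are illustrative, not part of the diff):

import torch
import torch.nn.functional as F
import torch_tensorrt

class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # no attn_mask (args[3] stays None), so attention_validator accepts the node
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# illustrative shapes; requires a CUDA build of torch_tensorrt
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.half)
trt_mod = torch_tensorrt.compile(SDPA().cuda().half(), ir="dynamo", inputs=[q, k, v])
print(trt_mod(q, k, v).shape)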

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
     activation,
     addmm,
     arange,
+    attention,
     cast,
     cat,
     condition,
py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+import math
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # the lower triangle of the tensor means the rows greater than or equal to the cols
+    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
+    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
+    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
+    arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
+    )
+    # get the rows
+    row_tensor = impl.elementwise.trunc_div(
+        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
+    )
+    # get the cols
+    col_tensor = impl.elementwise.fmod(
+        ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
+    )
+    cond = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
+    )
+    return impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape", cond, [row, col]
+    )
+
+
+def scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    query: TRTTensor,
+    key: TRTTensor,
+    value: TRTTensor,
+    is_causal: bool,
+    scale: Optional[float],
+) -> TRTTensor:
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            sqrt_scaled,
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
+
+    if is_causal:
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # if any of the L or S is dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, -2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)
+
+            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)
+
+            # this is to generate a tensor which has shape (L, S), type is int32
+            arange_tensor = impl.arange.arange(
+                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
+            )
+            shape_tensor = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
+            )
+
+            # since we want our attn_bias to be in float32, cast it to float32
+            shape_tensor = cast_trt_tensor(
+                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+            )
+
+            # initialize the attn_bias as the zeros tensor
+            attn_bias = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
+            )
+
+            # generate the mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            inf_tensor = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
+            )
+            cond = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
+            )
+            # mask out the upper-triangular part of the attn_bias
+            attn_bias = impl.condition.select(
+                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
+            )
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", scaled, -1, False
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
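
The tril helper above builds a lower-triangular boolean mask without per-row index tensors, which is what makes it usable when L and S are dynamic: it flattens an arange of length L*S, recovers the row as integer division by the column count and the column as the remainder, then compares the two. A quick NumPy check of the same trick (illustrative only; the converter builds the equivalent with TensorRT layers):

import numpy as np

L, S = 4, 5
idx = np.arange(L * S)
mask = (idx // S >= idx % S).reshape(L, S)   # row index >= column index
assert (mask == np.tril(np.ones((L, S), dtype=bool))).all()
print(mask.astype(int))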

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 32 additions & 20 deletions
@@ -3,6 +3,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -28,42 +29,53 @@ def quantize(
     """
 
     with unset_fake_temporarily():
-        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
-            trt.float32,
-            trt.float16,
-        ):
-            raise ValueError(
-                f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
-            )
+        if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
+            input_tensor = get_trt_tensor(ctx, input_tensor, name)
+            if input_tensor.dtype not in (
+                trt.float32,
+                trt.float16,
+            ):
+                raise ValueError(
+                    f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
+                )
             if num_bits != 8 or exponent_bits not in (0, 4):
                 raise ValueError(
                     f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
                 )
+        else:
+            raise ValueError(
+                f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
+            )
+
         if num_bits == 8 and exponent_bits == 0:
             max_bound = 127
         elif num_bits == 8 and exponent_bits == 4:
             max_bound = 448
 
-        amax = to_torch(amax, None)
-        scale = torch.divide(amax, max_bound)
-        scale = get_trt_tensor(ctx, scale, name + "_scale")
-        # Add Q node
-        quantize_layer = ctx.net.add_quantize(input_tensor, scale)
+        if not isinstance(amax, trt.ITensor):
+            amax = to_torch(amax, None)
+            scale = torch.divide(amax, max_bound)
+            scale = get_trt_tensor(ctx, scale, name + "_scale")
+        else:
+            scale = impl.elementwise.div(
+                ctx, target, source_ir, name + "_scale", amax, max_bound
+            )
+
         if num_bits == 8 and exponent_bits == 0:
-            quantize_layer.set_output_type(0, trt.DataType.INT8)
+            dtype = trt.DataType.INT8
         elif num_bits == 8 and exponent_bits == 4:
-            quantize_layer.set_output_type(0, trt.DataType.FP8)
+            dtype = trt.DataType.FP8
 
+        # Add Q node
+        quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
         set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
         q_output = quantize_layer.get_output(0)
         # Add DQ node
-        dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+        dequantize_layer = ctx.net.add_dequantize(
+            q_output, scale, output_type=input_tensor.dtype
+        )
+        dequantize_layer.to_type = input_tensor.dtype
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        if num_bits == 8 and exponent_bits == 0:
-            dequantize_layer.precision = trt.DataType.INT8
-        elif num_bits == 8 and exponent_bits == 4:
-            # Set DQ layer precision to FP8
-            dequantize_layer.precision = trt.DataType.FP8
         dq_output = dequantize_layer.get_output(0)
 
         return dq_output
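
The scale math here is the standard max-calibration rule: scale = amax / max_bound, with max_bound = 127 for INT8 (exponent_bits = 0) and 448 for FP8 E4M3 (exponent_bits = 4). A small numeric sketch of roughly what the inserted Q/DQ pair computes for the INT8 case (values are illustrative, not from the diff):

import torch

amax = torch.tensor(4.48)               # calibrated max |x| of the tensor
scale_fp8 = amax / 448                  # FP8 (E4M3) max_bound -> scale = 0.01
scale_int8 = amax / 127                 # INT8 max_bound -> scale ~= 0.0353

x = torch.tensor([4.48, -2.24, 0.56])
q = torch.clamp(torch.round(x / scale_int8), -127, 127)   # roughly what the Q node does
dq = q * scale_int8                                        # what the DQ node undoes
print(scale_fp8.item(), scale_int8.item(), dq)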

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 3 additions & 1 deletion
@@ -11,6 +11,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 
+from packaging import version as pkg_version
+
 logger = logging.getLogger(__name__)
 
 
@@ -24,7 +26,7 @@ def unsqueeze(
 ) -> TRTTensor:
     from importlib.metadata import version
 
-    if version("tensorrt") < "10.7.0":
+    if pkg_version.parse(version("tensorrt")) < pkg_version.parse("10.7.0"):
         logger.warning(
             f"IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version: {version('tensorrt')}"
         )
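
The packaging-based comparison matters because plain string comparison orders versions lexicographically, so "10.10.0" sorts before "10.7.0" even though it is the newer release. A quick illustration:

from packaging import version as pkg_version

assert "10.10.0" < "10.7.0"                                        # string order: wrong for versions
assert pkg_version.parse("10.10.0") > pkg_version.parse("10.7.0")  # semantic order: correct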

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from .constant_folding import constant_fold
 from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
@@ -23,6 +24,7 @@
         repair_input_as_output,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        lower_scaled_dot_product_attention,
         remove_assert_nodes,
         accumulate_fp32_matmul,
         remove_num_users_is_0_nodes,