Commit 5a2213e

add constant fold
1 parent d439d96 commit 5a2213e

2 files changed: +9 -5 lines

py/torch_tensorrt/dynamo/conversion/impl/quantize.py
Lines changed: 4 additions & 5 deletions

@@ -108,12 +108,11 @@ def dynamic_block_quantize(
         # Add Q node
         dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
             input_tensor,
-            axis=-1,
-            block_size=16,
-            output_type=trt.DataType.FP4,
-            scale_type=trt.DataType.FP8,
+            -1,
+            16,
+            trt.DataType.FP4,
+            trt.DataType.FP8,
         )
-        dynamic_quantize_layer.set_output_type(0, trt.DataType.FP4)

         set_layer_name(
             dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
Lines changed: 5 additions & 0 deletions

@@ -101,4 +101,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

     # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
+        if node.target in (
+            torch.ops.tensorrt.quantize_op.default,
+            torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        ):
+            return True
         return False
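
The is_impure override presumably exists so that the TensorRT quantize ops stay in the graph even when all of their inputs are constants, rather than being precomputed away by the constant-folding pass. Below is a minimal, self-contained sketch of that idea, not the torch_tensorrt pass itself: it uses only torch.fx, it turns is_impure into a free function for brevity (in the real pass it is a method on the folder class), and it uses torch.clamp as a hypothetical stand-in for the quantize ops, since torch.ops.tensorrt.quantize_op.default and torch.ops.tensorrt.dynamic_block_quantize_op.default only exist once torch_tensorrt registers them. It reports which call_function nodes the hook would protect; the actual folding logic is omitted.

    import torch
    import torch.fx

    # Hypothetical "never fold" set; torch.clamp stands in for the real
    # torch.ops.tensorrt.* quantize ops so the sketch runs without torch_tensorrt.
    _DO_NOT_FOLD = {torch.clamp}


    def is_impure(node: torch.fx.Node) -> bool:
        # Nodes the folder must leave in the graph even if their inputs are constant.
        return node.target in _DO_NOT_FOLD


    def is_foldable(node: torch.fx.Node) -> bool:
        # A call_function node could be folded if it is pure and every Node input
        # is a graph constant (a get_attr node, e.g. a parameter or buffer).
        if node.op != "call_function" or is_impure(node):
            return False
        return all(
            (not isinstance(arg, torch.fx.Node)) or arg.op == "get_attr"
            for arg in node.args
        )


    class Toy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(3))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # clamp(weight) has only constant inputs, so a naive folder would
            # precompute it; is_impure() tells the folder to keep it instead.
            q = torch.clamp(self.weight, 0.0, 1.0)
            return x + q


    gm = torch.fx.symbolic_trace(Toy())
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(f"{node.name}: impure={is_impure(node)}, foldable={is_foldable(node)}")

Running the sketch prints that the clamp node is impure and therefore never foldable, while the add node is pure but still not foldable because one of its inputs is the graph placeholder.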
