
Commit 5549da8

Enable {conv3d, conv_transpose3d} + bn fusion in pt2e (#2212)
* Enable {conv3d, conv_transpose3d} + bn fusion in pt2e

  Summary: as titled; previously only 1d and 2d fusion were supported, and this PR adds 3d support.

  Test Plan:
  python test/quantization/pt2e/test_quantize_pt2e.py -k test_conv3d_bn_relu
  python test/quantization/pt2e/test_quantize_pt2e.py -k test_conv_transpose3d_bn_relu

* comment

* fix test
1 parent: f04ff57
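For context, the conv/conv_transpose + bn folding that this commit extends runs on the exported graph before observers are inserted, which is why the tests below only annotate the conv/relu nodes. Below is a minimal sketch of exercising the new 3d path. The `_fuse_conv_bn_` helper name and its torchao import path are assumptions based on this file's upstream counterpart (torch/ao/quantization/pt2e/utils.py), and `export_for_training` requires PyTorch >= 2.5; this is a sketch, not part of the PR.

import torch
from torch.export import export_for_training

# Assumed location of the fusion helper; mirrors
# torch/ao/quantization/pt2e/utils.py upstream.
from torchao.quantization.pt2e.utils import _fuse_conv_bn_


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(2, 2, 3, padding=1)
        self.bn = torch.nn.BatchNorm3d(2)

    def forward(self, x):
        return torch.nn.functional.relu(self.bn(self.conv(x)))


example_inputs = (torch.randn(1, 2, 2, 5, 5),)
gm = export_for_training(M().eval(), example_inputs).module()

# With this commit, a 3d conv/bn pair is folded like the 1d/2d cases:
# after fusion no batch_norm node should remain in the graph.
_fuse_conv_bn_(gm)
assert not any(
    n.op == "call_function" and "batch_norm" in str(n.target)
    for n in gm.graph.nodes
)

In the prepare_pt2e flow this folding happens automatically and is not configurable, as the test comments note.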

File tree

2 files changed: +190 −1 lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 186 additions & 0 deletions
@@ -2385,6 +2385,192 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             node_list,
         )
 
+    def test_conv3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+                # conv + bn is fused automatically in PTQ (not configurable),
+                # so we just need to annotate conv + relu for the
+                # conv + bn + relu pattern
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv3d.default
+                    ):
+                        continue
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # two quantize ops (conv input, relu output); three dequantize
+            # ops (conv input, conv weight, relu output)
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv3d.default,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
+    def test_conv_transpose3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+                # conv_transpose + bn is fused automatically in PTQ (not
+                # configurable), so we just need to annotate
+                # conv_transpose + relu for the conv_transpose + bn + relu
+                # pattern
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv_transpose3d.input
+                    ):
+                        continue
+                    conv_t_node = n
+                    input_act = conv_t_node.args[0]
+                    weight = conv_t_node.args[1]
+                    bias = conv_t_node.args[2]
+                    conv_t_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_t = torch.nn.ConvTranspose3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv_t(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # two quantize ops (conv_transpose input, relu output); three
+            # dequantize ops (conv_transpose input, weight, relu output)
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose3d.input,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
 
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,
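The node_occurrence / node_list assertions above are evaluated by the test harness against the converted graph. As a rough, hypothetical illustration of what the occurrence check amounts to (reusing a converted GraphModule `gm`, e.g. the fusion sketch above after a prepare_pt2e -> calibration -> convert_pt2e round trip):

import torch
from collections import Counter

# Tally call_function targets in a converted GraphModule `gm`
# (hypothetical: produced by prepare_pt2e, calibration, convert_pt2e).
counts = Counter(n.target for n in gm.graph.nodes if n.op == "call_function")
assert counts[torch.ops.quantized_decomposed.quantize_per_tensor.default] == 2
assert counts[torch.ops.quantized_decomposed.dequantize_per_tensor.default] == 3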

torchao/quantization/pt2e/utils.py

Lines changed: 4 additions & 1 deletion
@@ -626,6 +626,7 @@ def _is_conv_node(n: Node):
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv3d.default,
     ]
 
 
@@ -638,6 +639,8 @@ def _is_conv_transpose_node(n: Node):
         torch.ops.aten.conv_transpose1d.default,
         torch.ops.aten.conv_transpose2d,
         torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d,
+        torch.ops.aten.conv_transpose3d.input,
     ]
 
 
@@ -649,7 +652,7 @@ def _is_conv_or_conv_transpose_node(n: Node):
 
 
 def _is_conv_transpose_fn(conv_fn: Callable):
-    return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
+    return conv_fn in [F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d]
 
 
 def _is_bn_node(n: Node):
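A quick way to see the extended predicates in action is to export a small 3d module and check that its aten nodes are now recognized. This is a sketch, assuming `export_for_training` (PyTorch >= 2.5) preserves `aten.conv3d.default` and `aten.conv_transpose3d.input` pre-decomposition, as the pt2e capture flow does:

import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.utils import (
    _is_conv_node,
    _is_conv_transpose_node,
)


class C(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(2, 2, 3)
        self.conv_t = torch.nn.ConvTranspose3d(2, 2, 3)

    def forward(self, x):
        return self.conv_t(self.conv(x))


gm = export_for_training(C().eval(), (torch.randn(1, 2, 5, 5, 5),)).module()
# Before this change, both checks would fail for the 3d variants.
assert any(_is_conv_node(n) for n in gm.graph.nodes)
assert any(_is_conv_transpose_node(n) for n in gm.graph.nodes)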
