Commit 02763b6

Patch the _is_conv_node function (#2257)

Authored by cccclai, committed by facebook-github-bot.

Summary: Add the conv padding ops in torch/ao only; a separate PR will add the ones in pytorch.

Differential Revision: D75323215

Parent: 1017c7e
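Background for the patch: nn.Conv* modules constructed with string padding ("same"/"valid") trace to the aten .padding overloads rather than .default, so a predicate that only checks the .default overloads misses them. A minimal sketch illustrating this (not part of the commit; assumes a recent PyTorch with torch.export):

import torch

# A conv module with string padding; integer padding would trace to
# torch.ops.aten.conv2d.default instead.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding="same")

    def forward(self, x):
        return self.conv(x)

ep = torch.export.export(M(), (torch.randn(1, 3, 5, 5),))
targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
print(torch.ops.aten.conv2d.padding in targets)  # expected: True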

File tree: 2 files changed, +121 −0 lines


test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 118 additions & 0 deletions
@@ -2570,7 +2570,125 @@ def forward(self, x):
             node_occurrence,
             node_list,
         )
+
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
 
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                            weight: weight_qspec,
+                            bias: bias_qspec,
+                        },
+                        _annotated=True,
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,
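To run just the new test locally, something like the following should work (illustrative command, not part of the commit; assumes a torchao source checkout with pytest installed):

pytest test/quantization/pt2e/test_quantize_pt2e.py -k test_conv_padding_bn_relu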

torchao/quantization/pt2e/utils.py

Lines changed: 3 additions & 0 deletions
@@ -625,8 +625,11 @@ def _is_conv_node(n: Node):
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv1d.padding,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
     ]
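For reference, the patched predicate reads as follows once the hunk is applied (assembled from the diff above; the docstring body sits above the hunk and is elided here):

def _is_conv_node(n: Node):
    """..."""  # docstring elided; not shown in the hunk
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv1d.padding,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv2d.padding,
        torch.ops.aten.conv3d.default,
        torch.ops.aten.conv3d.padding,
    ]

With the .padding overloads included, conv nodes produced by padding="same" modules are recognized by the same quantization passes as their integer-padding counterparts, which is what the new test exercises.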