@@ -2570,6 +2570,125 @@ def forward(self, x):
             node_occurrence,
             node_list,
         )
+
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )

     def test_multi_users_without_output_observer(self):
         """
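For context, a minimal sketch of the PT2E flow that the `_test_quantizer` helper wraps, showing how a quantizer like `BackendAQuantizer` would be driven directly. This is illustrative, not part of the diff: it assumes the `prepare_pt2e`/`convert_pt2e` entry points from `torch.ao.quantization.quantize_pt2e`, and the export step (`torch.export.export_for_training` here) has varied across torch releases.

```python
# Hypothetical driver sketch, not from this PR. Assumes the Conv2d variant
# of M and BackendAQuantizer as defined in the test above, and that
# torch.export.export_for_training is available (older releases used
# capture_pre_autograd_graph instead).
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

# Export the eager model to an ATen-level FX graph.
exported = torch.export.export_for_training(model, example_inputs).module()

# Insert observers wherever the quantizer's annotate() attached qspecs.
prepared = prepare_pt2e(exported, BackendAQuantizer())

# Calibrate: run representative data so the observers record value ranges.
prepared(*example_inputs)

# Replace observers with quantize/dequantize ops -- the nodes the test
# asserts on via node_occurrence and node_list.
quantized = convert_pt2e(prepared)
```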