@@ -2570,7 +2570,125 @@ def forward(self, x):
             node_occurrence,
             node_list,
         )
+
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
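+                # Activations: asymmetric per-tensor uint8 quantization.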
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
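+                # Weights: per-tensor int8 quantization.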
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
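+                # Bias stays in float32; PlaceholderObserver is a pass-through.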
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
 
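+                # Match the conv -> relu pattern by walking backwards from each relu node.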
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv overloads that take string padding
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
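+                    # Annotate the conv inputs (activation, weight, bias) and the
+                    # relu output so quantize/dequantize ops wrap the fused pattern.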
+                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                            weight: weight_qspec,
+                            bias: bias_qspec,
+                        },
+                        _annotated=True,
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
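+                # Expected: two quantize ops (conv input and relu output) and three
+                # dequantize ops (input activation, weight, and the final output).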
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,