@@ -2385,6 +2385,192 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             node_list,
         )
 
+    def test_conv3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+                # conv + bn is fused automatically in PTQ (not configurable),
+                # so we only need to annotate conv + relu to cover the
+                # conv + bn + relu pattern
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
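+                    # bn has already been folded into the conv at this point,
+                    # so relu's producer should be the conv3d node itself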
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv3d.default
+                    ):
+                        continue
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # one quantize each for the conv input activation and the relu output
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            # one dequantize each for the conv input, the weight, and the relu output
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
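+        # expected order of ops in the converted graph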
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv3d.default,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
+    def test_conv_transpose3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+                # conv_transpose + bn is fused automatically in PTQ (not configurable),
+                # so we only need to annotate conv_transpose + relu to cover the
+                # conv_transpose + bn + relu pattern
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
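+                    # bn has already been folded into the conv_transpose at this
+                    # point, so relu's producer should be the conv_transpose3d node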
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv_transpose3d.input
+                    ):
+                        continue
+                    conv_t_node = n
+                    input_act = conv_t_node.args[0]
+                    weight = conv_t_node.args[1]
+                    bias = conv_t_node.args[2]
+                    conv_t_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_t = torch.nn.ConvTranspose3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv_t(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # one quantize each for the conv_transpose input activation and the relu output
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            # one dequantize each for the conv_transpose input, the weight, and the relu output
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
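+        # expected order of ops in the converted graph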
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose3d.input,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,