@@ -1964,3 +1964,99 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams.tensor(
+                        x,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        eps=1e-5,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                    )
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                        x,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_token.default(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                        output_dtype=torch.float32,
+                    )
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
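
For reference, the same decomposed-op round trips can be exercised in eager mode, outside the Vulkan delegate. A minimal sketch follows; it assumes that importing torch.ao.quantization.fx._decomposed is what registers the quantized_decomposed op library in a standalone script (in the test suite, the ops arrive registered through its own imports), and the call signatures simply mirror the test code above:

# Eager-mode sketch of the two quantize/dequantize round trips added above.
# Assumption: importing torch.ao.quantization.fx._decomposed registers the
# quantized_decomposed op library when running outside the test suite.
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401

INT32_MIN, INT32_MAX = -2147483648, 2147483647

# Per-tensor: a single (scale, zero_point) pair for the whole tensor.
x = torch.rand(2, 3, 4)
scale, zp = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, quant_min=INT32_MIN, quant_max=INT32_MAX, eps=1e-5, dtype=torch.int32
)
q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zp, quant_min=INT32_MIN, quant_max=INT32_MAX, dtype=torch.int32
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
    q, scale, zp, quant_min=INT32_MIN, quant_max=INT32_MAX, dtype=torch.int32
)
assert torch.allclose(x, dq, atol=5e-3, rtol=5e-3)

# Per-token: one (scale, zero_point) pair per row (token) of the input.
t = torch.rand(6, 4)
scales, zps = (
    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
        t, dtype=torch.int32
    )
)
qt = torch.ops.quantized_decomposed.quantize_per_token.default(
    t, scales, zps, quant_min=INT32_MIN, quant_max=INT32_MAX, dtype=torch.int32
)
dqt = torch.ops.quantized_decomposed.dequantize_per_token.default(
    qt, scales, zps, quant_min=INT32_MIN, quant_max=INT32_MAX,
    dtype=torch.int32, output_dtype=torch.float32,
)
assert torch.allclose(t, dqt, atol=5e-3, rtol=5e-3)

Both asserts reuse the atol/rtol of 5e-3 passed to lower_module_and_test_output, since a quantize/dequantize round trip is close to, but not exactly, the identity.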