@@ -285,10 +285,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
         self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
-        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
-        if not self.int8_lm_head and extra_options.get("int4_algo_config", "default") != "rtn":
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
+        if not self.int8_lm_head and extra_options.get("int4_algo_config", "default") not in {"rtn", "k_quant"}:
             # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
-            # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn on 4bit
+            # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn and k_quant on 4-bit
             self.int4_tied_embeddings = False
 
     def to_str_dtype(self, dtype: ir.DataType) -> str:
@@ -489,28 +489,35 @@ def make_int4_algo_config(self, quant_method: str):
         customized_weight_config = {}
         int4_algo_config = None
 
-        if quant_method == "rtn":
-            int4_algo_config = RTNWeightOnlyQuantConfig()
+        if quant_method in {"rtn", "rtn_last"}:
+            if quant_method == "rtn":
+                int4_algo_config = RTNWeightOnlyQuantConfig()
+            elif quant_method == "rtn_last":
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+                int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
 
-        elif quant_method in {"k_quant_mixed", "k_quant_last"}:
+        elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
             from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
 
-            if quant_method == "k_quant_mixed":
-                # k_quant_mixed is from llama.cpp.
-                # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
-                # We also consider some MatMuls are more senstive to quantization than other MatMuls.
-                layers_to_exclude = [
-                    i
-                    for i in range(self.num_layers)
-                    if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
-                ]
-                for i in layers_to_exclude:
-                    customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
-                    customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
-                    customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
-
-            customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
-            int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+            if quant_method == "k_quant":
+                int4_algo_config = KQuantWeightOnlyQuantConfig()
+            else:
+                if quant_method == "k_quant_mixed":
+                    # k_quant_mixed is from llama.cpp.
+                    # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
+                    # We also consider that some MatMuls are more sensitive to quantization than others.
+                    layers_to_exclude = [
+                        i
+                        for i in range(self.num_layers)
+                        if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
+                    ]
+                    for i in layers_to_exclude:
+                        customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
+                        customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
+                        customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
+
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+                int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
 
         return int4_algo_config
 
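As a quick sanity check on the k_quant_mixed layer-exclusion pattern above, here is a minimal standalone sketch. It only mirrors the list comprehension from make_int4_algo_config (num_layers = 32 is an assumed example value, not taken from any particular model) rather than calling the builder itself:

# Minimal sketch of the k_quant_mixed per-layer override selection above.
# Assumption: num_layers = 32 is a hypothetical example value.
num_layers = 32

# Same predicate as the list comprehension in make_int4_algo_config:
# keep the first eighth, the last eighth, and every third layer in between at 8 bits.
layers_to_exclude = [
    i
    for i in range(num_layers)
    if i < num_layers / 8 or i >= 7 * num_layers / 8 or (i - round(num_layers / 8)) % 3 == 2
]

# Build the same override dict the builder passes to KQuantWeightOnlyQuantConfig:
# qkv_proj, v_proj, and down_proj in the selected layers (plus lm_head) stay at 8 bits;
# all other MatMuls are quantized to 4 bits.
customized_weight_config = {}
for i in layers_to_exclude:
    customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
    customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
    customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

print(layers_to_exclude)
# [0, 1, 2, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 29, 30, 31]

Per the new branches above, int4_algo_config=rtn_last keeps only lm_head at 8 bits under RTN, and plain k_quant uses KQuantWeightOnlyQuantConfig with no per-layer overrides; assuming the builder's usual --extra_options KEY=VALUE form, that would be passed as something like --extra_options int4_algo_config=rtn_last.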