Commit b75bb1e

Added rtn mixed precision quant options and managed tied 4b embs conditions
1 parent: 958574d

File tree: 1 file changed, +29 −22 lines changed

src/python/py/models/builder.py

Lines changed: 29 additions & 22 deletions
@@ -285,10 +285,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
         self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
-        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
-        if not self.int8_lm_head and extra_options.get("int4_algo_config", "default") != "rtn":
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
+        if not self.int8_lm_head and extra_options.get("int4_algo_config", "default") not in {"rtn", "k_quant"}:
             # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
-            # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn on 4bit
+            # tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn & k_quant on 4bit
             self.int4_tied_embeddings = False
 
     def to_str_dtype(self, dtype: ir.DataType) -> str:
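For context, the gating above can be read in isolation. Below is a minimal standalone sketch of that logic; the function name resolve_lm_head_options and its inputs are illustrative stand-ins for the builder's real state, not part of builder.py:

# Sketch of the int8 lm_head / tied-embeddings gating introduced in this commit.
# "extra_options" and "tie_word_embeddings" stand in for the builder's real inputs.
def resolve_lm_head_options(extra_options: dict, tie_word_embeddings: bool):
    algo = extra_options.get("int4_algo_config", "default")

    # rtn_last now also promotes lm_head to 8 bits, alongside the k_quant mixed variants.
    int8_lm_head = algo in {"k_quant_mixed", "k_quant_last", "rtn_last"}

    # Tied 4-bit embeddings are only kept for plain rtn or plain k_quant; other
    # algorithms name lm_head.MatMul.weight_Q{}G{} differently, so sharing would break.
    int4_tied_embeddings = extra_options.get("int4_tied_embeddings", tie_word_embeddings)
    if not int8_lm_head and algo not in {"rtn", "k_quant"}:
        int4_tied_embeddings = False
    return int8_lm_head, int4_tied_embeddings

print(resolve_lm_head_options({"int4_algo_config": "rtn_last"}, True))   # (True, True)
print(resolve_lm_head_options({"int4_algo_config": "k_quant"}, True))    # (False, True)
print(resolve_lm_head_options({"int4_algo_config": "default"}, True))    # (False, False)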
@@ -489,28 +489,35 @@ def make_int4_algo_config(self, quant_method: str):
         customized_weight_config = {}
         int4_algo_config = None
 
-        if quant_method == "rtn":
-            int4_algo_config = RTNWeightOnlyQuantConfig()
+        if quant_method in {"rtn", "rtn_last"}:
+            if quant_method == "rtn":
+                int4_algo_config = RTNWeightOnlyQuantConfig()
+            elif quant_method == "rtn_last":
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+                int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
 
-        elif quant_method in {"k_quant_mixed", "k_quant_last"}:
+        elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
             from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
 
-            if quant_method == "k_quant_mixed":
-                # k_quant_mixed is from llama.cpp.
-                # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
-                # We also consider that some MatMuls are more sensitive to quantization than others.
-                layers_to_exclude = [
-                    i
-                    for i in range(self.num_layers)
-                    if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
-                ]
-                for i in layers_to_exclude:
-                    customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
-                    customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
-                    customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
-
-            customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
-            int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+            if quant_method == "k_quant":
+                int4_algo_config = KQuantWeightOnlyQuantConfig()
+            else:
+                if quant_method == "k_quant_mixed":
+                    # k_quant_mixed is from llama.cpp.
+                    # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
+                    # We also consider that some MatMuls are more sensitive to quantization than others.
+                    layers_to_exclude = [
+                        i
+                        for i in range(self.num_layers)
+                        if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - (round)(self.num_layers / 8)) % 3 == 2
+                    ]
+                    for i in layers_to_exclude:
+                        customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
+                        customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
+                        customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
+
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+                int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
 
         return int4_algo_config
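To make the mixed-precision selection concrete, here is a small standalone sketch that reproduces the layer-exclusion rule and the per-MatMul overrides for a hypothetical 32-layer model; the layer count and printed output are illustrative only, builder.py uses self.num_layers:

# Hypothetical 32-layer example of the k_quant_mixed / rtn_last override tables.
num_layers = 32

# llama.cpp-style rule for picking which decoder layers get 8-bit treatment for
# their most quantization-sensitive MatMuls: the first and last eighth of the
# layers, plus every third layer in between.
layers_to_exclude = [
    i
    for i in range(num_layers)
    if i < num_layers / 8 or i >= 7 * num_layers / 8 or (i - round(num_layers / 8)) % 3 == 2
]
print(layers_to_exclude)
# [0, 1, 2, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 29, 30, 31]

# k_quant_mixed: those layers' qkv/v/down projections and lm_head are kept at 8 bits;
# all other MatMuls stay 4-bit.
customized_weight_config = {}
for i in layers_to_exclude:
    customized_weight_config[f"/model/layers.{i}/attn/qkv_proj/MatMul"] = {"bits": 8}
    customized_weight_config[f"/model/layers.{i}/attn/v_proj/MatMul"] = {"bits": 8}
    customized_weight_config[f"/model/layers.{i}/mlp/down_proj/MatMul"] = {"bits": 8}
customized_weight_config["/lm_head/MatMul"] = {"bits": 8}

# rtn_last is the lighter new option: only lm_head is promoted to 8 bits.
rtn_last_config = {"/lm_head/MatMul": {"bits": 8}}

The k_quant_last and plain rtn/k_quant paths need no per-layer table: k_quant_last only adds the lm_head override, while rtn and k_quant use the default, uncustomized 4-bit configs.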
