
Commit f6386c2

Ronak Mahawar committed

Mixed precision export support for GPTQ-quantized models

1 parent: ded6e97

2 files changed: +24 −2 lines


src/python/py/models/builder.py

Lines changed: 8 additions & 2 deletions
@@ -2004,9 +2004,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):

         # Unpack attention weights if needed
         self.make_attention_unpacked(layer_id, attention, root_input, **kwargs)
+
+        # Get dtype used for MatMul ops
+        q_dtype = getattr(attention.q_proj, "bits", None) or getattr(attention.q_proj.weight, "dtype", None)
+        k_dtype = getattr(attention.k_proj, "bits", None) or getattr(attention.k_proj.weight, "dtype", None)
+        v_dtype = getattr(attention.v_proj, "bits", None) or getattr(attention.v_proj.weight, "dtype", None)
+        all_dtype_equal = q_dtype == k_dtype == v_dtype

         # Make MatMul nodes
-        if self.attention_attrs["use_packed_matmul"]:
+        if self.attention_attrs["use_packed_matmul"] and all_dtype_equal:
             # Combine 3 MatMuls into 1 packed MatMul
             qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
             qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input)

@@ -2028,7 +2034,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0
         any_bias_exists = q_bias_exists or k_bias_exists or v_bias_exists

-        if self.attention_attrs["use_packed_matmul"] and any_bias_exists:
+        if self.attention_attrs["use_packed_matmul"] and any_bias_exists and all_dtype_equal:
             # Combine 3 Adds into 1 packed Add
             qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
             self.make_packed_add(attention.q_proj.bias, attention.k_proj.bias, attention.v_proj.bias, qkv_add_name, root_input=self.attention_attrs["q_path"])
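As a standalone illustration of the guard added above, here is a minimal sketch (not from the repository) using stand-in projection objects: when the q/k/v projections are quantized to different bit-widths, their weights cannot be concatenated into one packed QKV MatMul, so the builder falls back to three separate MatMuls. The SimpleNamespace layers and bit-widths below are hypothetical.

    from types import SimpleNamespace

    # Hypothetical mixed-precision attention block: q_proj overridden to
    # 8 bits, k_proj/v_proj left at the global 4-bit setting.
    q_proj = SimpleNamespace(bits=8)
    k_proj = SimpleNamespace(bits=4)
    v_proj = SimpleNamespace(bits=4)

    # Mirrors the dtype check in make_attention (the weight-dtype fallback
    # for non-quantized layers is omitted in this sketch).
    q_dtype = getattr(q_proj, "bits", None)
    k_dtype = getattr(k_proj, "bits", None)
    v_dtype = getattr(v_proj, "bits", None)

    all_dtype_equal = q_dtype == k_dtype == v_dtype
    print(all_dtype_equal)  # False -> packed QKV MatMul/Add is skipped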

src/python/py/models/quantized_model.py

Lines changed: 16 additions & 0 deletions
@@ -863,6 +863,22 @@ def __init__(self, module):
         self.pack_qzeros(temp_module)
         module.qzeros = temp_module.qzeros

+    def _load_quant_config(self, quant_attrs):
+        super()._load_quant_config(quant_attrs)
+        self.overrides = quant_attrs["config"].get("dynamic", {})
+
+    def get_overrides(self, layer_name):
+        for pattern, overrides in self.overrides.items():
+            if re.match(pattern.removeprefix("+:"), layer_name):
+                return overrides
+        return {}
+
+    def get_layer_bits(self, layer_name):
+        return self.get_overrides(layer_name).get("bits", self.global_bits)
+
+    def get_layer_group_size(self, layer_name):
+        return self.get_overrides(layer_name).get("group_size", self.global_group_size)
+
 class QuarkModel(QuantizedModel):
     def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
         super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
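The lookup helpers added above resolve per-layer overrides from the quantization config's "dynamic" section. The following self-contained sketch shows that behavior, assuming "dynamic" maps regex patterns (optionally prefixed with "+:") to per-layer settings; the pattern, layer names, and values below are hypothetical.

    import re

    global_bits, global_group_size = 4, 128

    # Assumed shape of the "dynamic" override section of a GPTQ config.
    dynamic = {
        r"+:.*\.q_proj.*": {"bits": 8, "group_size": 64},
    }

    def get_overrides(layer_name):
        for pattern, overrides in dynamic.items():
            if re.match(pattern.removeprefix("+:"), layer_name):
                return overrides
        return {}

    layer = "model.layers.0.self_attn.q_proj"
    print(get_overrides(layer).get("bits", global_bits))              # 8
    print(get_overrides(layer).get("group_size", global_group_size))  # 64
    print(get_overrides("model.layers.0.mlp.up_proj")
          .get("bits", global_bits))                                  # 4 (global default)

A layer resolved to a different bit-width than its q/k/v siblings is exactly the case the builder.py change guards against: such projections can no longer share a packed QKV MatMul.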
