@@ -2004,9 +2004,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
 
         # Unpack attention weights if needed
         self.make_attention_unpacked(layer_id, attention, root_input, **kwargs)
+
+        # Get dtype used for MatMul ops
+        q_dtype = getattr(attention.q_proj, "bits", None) or getattr(attention.q_proj.weight, "dtype", None)
+        k_dtype = getattr(attention.k_proj, "bits", None) or getattr(attention.k_proj.weight, "dtype", None)
+        v_dtype = getattr(attention.v_proj, "bits", None) or getattr(attention.v_proj.weight, "dtype", None)
+        all_dtype_equal = q_dtype == k_dtype == v_dtype
 
         # Make MatMul nodes
-        if self.attention_attrs["use_packed_matmul"]:
+        if self.attention_attrs["use_packed_matmul"] and all_dtype_equal:
             # Combine 3 MatMuls into 1 packed MatMul
             qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
             qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input)
@@ -2028,7 +2034,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0
         any_bias_exists = q_bias_exists or k_bias_exists or v_bias_exists
 
-        if self.attention_attrs["use_packed_matmul"] and any_bias_exists:
+        if self.attention_attrs["use_packed_matmul"] and any_bias_exists and all_dtype_equal:
             # Combine 3 Adds into 1 packed Add
             qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
             self.make_packed_add(attention.q_proj.bias, attention.k_proj.bias, attention.v_proj.bias, qkv_add_name, root_input=self.attention_attrs["q_path"])
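
For context on the new guard, here is a minimal, self-contained sketch (not the builder code itself) of the dtype-compatibility check: the Q/K/V projections are only fused into one packed MatMul when they all report the same precision, comparing an assumed quantization bit-width attribute (`bits`) when present and otherwise the floating-point weight dtype. The helper name `can_pack_qkv` is hypothetical.

```python
# Standalone illustration; can_pack_qkv is a hypothetical helper, not part of the builder.
import torch.nn as nn

def can_pack_qkv(q_proj, k_proj, v_proj) -> bool:
    """Return True when all three projections use the same precision."""
    def proj_dtype(proj):
        # Prefer a quantization bit width ("bits") when the module exposes one;
        # otherwise fall back to the floating-point weight dtype.
        return getattr(proj, "bits", None) or getattr(proj.weight, "dtype", None)
    return proj_dtype(q_proj) == proj_dtype(k_proj) == proj_dtype(v_proj)

# Three float16 projections can share one packed MatMul...
q = nn.Linear(64, 64).half()
k = nn.Linear(64, 64).half()
v = nn.Linear(64, 64).half()
assert can_pack_qkv(q, k, v)

# ...but a float32 V projection forces the builder to keep separate MatMuls.
v_fp32 = nn.Linear(64, 64)
assert not can_pack_qkv(q, k, v_fp32)
```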