@@ -911,14 +911,12 @@ def shard_low_precision_checkpoint(
 
     """
     assert tp_grain_size % 8 == 0, "tp_grain_size must be a multiple of 8"
-    if isinstance(model_config, dict):
-        num_heads = model_config["num_attention_heads"]
-        if "num_key_value_heads" in model_config:
-            num_heads = model_config["num_key_value_heads"]
-    else:
-        num_heads = model_config.num_attention_heads
-        if "num_key_value_heads" in model_config:
-            num_heads = model_config.num_key_value_heads
+    if not isinstance(model_config, dict):
+        model_config = model_config.to_dict()
+    num_heads = model_config["num_attention_heads"]
+    num_kv_heads = num_heads
+    if "num_key_value_heads" in model_config:
+        num_kv_heads = model_config["num_key_value_heads"]
     local_rank = rank
 
     mha_layers_split_by_N = [
@@ -928,6 +926,9 @@ def shard_low_precision_checkpoint(
928926 "q_b_proj" ,
929927 "kv_b_proj" ,
930928 ]
929+ qkv_proj_layers = [
930+ "qkv_proj" ,
931+ ]
931932 # mlp is split with grain size = tp_grain_size
932933 mlp_layers_split_by_N = [
933934 "gate_proj" ,
@@ -938,6 +939,9 @@ def shard_low_precision_checkpoint(
938939 "w1" ,
939940 "w3" ,
940941 ]
942+ gate_up_proj_layers = [
943+ "gate_up_proj" ,
944+ ]
941945 mha_layers_split_by_K = [
942946 "o_proj" ,
943947 "out_proj" ,
@@ -952,20 +956,28 @@ def shard_low_precision_checkpoint(
952956 "w2" ,
953957 ]
954958 lm_head_layers = ["lm_head" ] # split by K but not quantized
959+
960+ def _key_belongs_to (key , layer_group ):
961+ key_split = key .split ("." )
962+ for layer in layer_group :
963+ if layer in key_split :
964+ return True
965+ return False
966+
955967 low_precision_checkpoint_dict = low_precision_checkpoint .copy ()
956968 head_range = [0 ]
957- head_per_rank = num_heads // world_size
969+ head_per_rank = num_kv_heads // world_size
958970 for i in range (0 , world_size ):
959971 head_this_rank = head_per_rank
960- if i < num_heads % world_size :
972+ if i < num_kv_heads % world_size :
961973 head_this_rank += 1
962974 head_range .append (head_range [- 1 ] + head_this_rank )
963975 for key in low_precision_checkpoint .keys ():
964976 q_head_start = head_range [rank ]
965977 q_head_end = q_head_start + (head_range [rank + 1 ] - head_range [rank ])
966978 if "bias" in key :
967979 continue
968- if any ( substring in key for substring in mha_layers_split_by_N ):
980+ if _key_belongs_to ( key , mha_layers_split_by_N ):
969981 data = low_precision_checkpoint_dict [key ]
970982 if quantization_method == "awq" :
971983 # qweight shape: [K, N // 8]
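Switching from any(substring in key ...) to _key_belongs_to matters because the new fused layer names contain the old names as substrings (for example, "up_proj" is a substring of "gate_up_proj"). A small standalone demo; the checkpoint key is hypothetical:

def _key_belongs_to(key, layer_group):
    key_split = key.split(".")
    for layer in layer_group:
        if layer in key_split:
            return True
    return False

key = "model.layers.0.mlp.gate_up_proj.qweight"  # hypothetical key
print("up_proj" in key)                          # True: substring false positive
print(_key_belongs_to(key, ["up_proj"]))         # False: exact path-component match
print(_key_belongs_to(key, ["gate_up_proj"]))    # True

# The KV-head split above balances remainders the same way:
# num_kv_heads=10, world_size=4 gives head_range == [0, 3, 6, 8, 10]
# (ranks 0 and 1 take 3 heads each, ranks 2 and 3 take 2).
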
@@ -1046,7 +1058,91 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_N):
+        elif _key_belongs_to(key, qkv_proj_layers):
+            # split the fused q, k and v projections, shard each separately,
+            # then concat the shards back together (an MHA layer split by N)
+            data = low_precision_checkpoint_dict[key]
+            hidden_size = model_config["hidden_size"]
+            head_dim = hidden_size // num_heads
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                N_pack_factor = 1 if "scales" in key else 8
+                N = data.shape[-1] * N_pack_factor
+                q_pos = N - 2 * num_kv_heads * head_dim
+                k_pos = q_pos + num_kv_heads * head_dim
+                v_pos = k_pos + num_kv_heads * head_dim
+                q_pos //= N_pack_factor
+                k_pos //= N_pack_factor
+                v_pos //= N_pack_factor
+                data_list = [
+                    data[:, :q_pos],
+                    data[:, q_pos:k_pos],
+                    data[:, k_pos:v_pos],
+                ]
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                    q_head_start = local_rank
+                    q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = []
+                if "g_idx" not in key:
+                    N_pack_factor = 8 if "qzeros" in key else 1
+                    N = data.shape[-1] * N_pack_factor
+                    q_pos = N - 2 * num_kv_heads * head_dim
+                    k_pos = q_pos + num_kv_heads * head_dim
+                    v_pos = k_pos + num_kv_heads * head_dim
+                    q_pos //= N_pack_factor
+                    k_pos //= N_pack_factor
+                    v_pos //= N_pack_factor
+                    data_list = [
+                        data[:, :q_pos],
+                        data[:, q_pos:k_pos],
+                        data[:, k_pos:v_pos],
+                    ]
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                    q_head_start = local_rank
+                    q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mlp_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
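The fused-QKV hunk above locates the q/k/v column boundaries by working back from the tail of the packed output dimension. A worked example under assumed shapes (hidden_size=4096, num_heads=32, num_kv_heads=8; none of these numbers come from the PR):

head_dim = 4096 // 32                 # 128
N = (32 + 2 * 8) * head_dim           # 6144 fused output columns: q | k | v
q_pos = N - 2 * 8 * head_dim          # 4096 -> q spans cols [0, 4096)
k_pos = q_pos + 8 * head_dim          # 5120 -> k spans cols [4096, 5120)
v_pos = k_pos + 8 * head_dim          # 6144 -> v spans cols [5120, 6144)
# For an AWQ qweight ([K, N // 8]) the boundaries shrink by the pack factor:
# 4096 // 8 = 512, 5120 // 8 = 640, 6144 // 8 = 768.
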
@@ -1183,7 +1279,95 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mha_layers_split_by_K):
+        elif _key_belongs_to(key, gate_up_proj_layers):
+            # split the fused gate and up projections, shard each separately,
+            # then concat the shards back together (an MLP layer split by N)
+            data = low_precision_checkpoint_dict[key]
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if "scales" in key:
+                        assert (
+                            data.shape[1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // tp_grain_size
+                        dim = tp_grain_size
+                    else:
+                        assert (
+                            data.shape[1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if "qzeros" in key:
+                        assert (
+                            data.shape[-1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    elif "g_idx" not in key:  # qweight, scales
+                        assert (
+                            data.shape[-1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // tp_grain_size
+                        dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mha_layers_split_by_K):
+            if "bias" in key:
+                continue
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
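The gate_up_proj hunk above reuses the grain-based split: each chunk's N is carved into tp_grain_size-wide grains, and any remainder grains go to the lowest ranks. A small worked example under assumed sizes (N=11008, tp_grain_size=64, world_size=4, all illustrative):

tp_grain_size, world_size = 64, 4
grains = 11008 // tp_grain_size            # 172 grains of 64 columns each
for local_rank in range(world_size):
    grains_per_rank = grains // world_size  # 43
    grains_rem = grains % world_size        # 0
    grains_start = grains_per_rank * local_rank + min(local_rank, grains_rem)
    grains_end = grains_start + grains_per_rank + (1 if local_rank < grains_rem else 0)
    print(local_rank, grains_start * tp_grain_size, grains_end * tp_grain_size)
# rank 0 -> cols [0, 2752), rank 1 -> [2752, 5504),
# rank 2 -> [5504, 8256), rank 3 -> [8256, 11008)
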
@@ -1271,10 +1455,10 @@ def shard_low_precision_checkpoint(
                 q_head_end = local_rank + 1
                 low_precision_checkpoint_dict[key] = data[
                     :, q_head_start * dim : q_head_end * dim
-                ].contiguous()
+                ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_K):
+        elif _key_belongs_to(key, mlp_layers_split_by_K):
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
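Both this hunk and the next drop .contiguous() on a column slice. A last-dim slice is a view, and .contiguous() forces a copy, so the change presumably avoids an extra materialization for these K-split tensors. A quick standalone check (not code from this PR):

import torch

t = torch.arange(12).reshape(3, 4)
s = t[:, 1:3]                        # a view into t, no copy made
print(s.is_contiguous())             # False
c = s.contiguous()                   # materializes a fresh, contiguous copy
print(c.data_ptr() != t.data_ptr())  # True: new storage was allocated
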
@@ -1424,10 +1608,10 @@ def shard_low_precision_checkpoint(
                 )
                 low_precision_checkpoint_dict[key] = data[
                     :, grains_start * dim : grains_end * dim
-                ].contiguous()
+                ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in lm_head_layers):
+        elif _key_belongs_to(key, lm_head_layers):
             # lm_head: [N, K] (not quantized)
             # Same for all quantization methods
             data = low_precision_checkpoint_dict[key]
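lm_head stays unquantized, so its split-by-K is a plain column slice of an [N, K] weight rather than grain or head arithmetic. For illustration only (vocab/hidden sizes are assumptions, not from the PR):

import torch

vocab, hidden, world_size, rank = 32000, 4096, 4, 0
lm_head = torch.empty(vocab, hidden)                 # [N, K], not packed or quantized
cols = hidden // world_size                          # 1024 columns of K per rank
shard = lm_head[:, rank * cols : (rank + 1) * cols]  # [32000, 1024] on each rank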