
Commit ea628af

Port a few commits from release/2.7 to release/2.8 (#3769)
* WOQ: reduce memory usage for DeepSeek R1 DA8W8 with TP (#3645)
* WOQ: fix phi3 issue with latest DeepSpeed and TP=6 (#3655)
* LLM BKC: Always set config.use_cache to true (#3657)

Co-authored-by: Chunyuan WU <[email protected]>
1 parent 6e23534 · commit ea628af

File tree: 5 files changed, +209 additions, −25 deletions

csrc/cpu/aten/kernels/WoqUtilKrnl.cpp

Lines changed: 4 additions & 2 deletions
@@ -36,9 +36,9 @@ at::Tensor qlinear_woq_pack(
     size_t block_k,
     int64_t lowp_mode,
     int64_t weight_format) {
-  TLA_ASSERT(qw.is_contiguous(), "qw must be contiguous");
   bool is_4bit_flag = is_4bit(qw_type);
   auto sizes = qw.sizes();
+  auto strides = qw.strides();
   auto N = sizes[0];
   auto K = is_4bit_flag ? sizes[1] * 2 : sizes[1];
   if (weight_format == GPTQ_WEIGHT_FORMAT) {
@@ -66,6 +66,7 @@ at::Tensor qlinear_woq_pack(
   const int Nc = N / block_n;
   const int Kc = K / block_k;
   if (is_4bit_flag) {
+    TORCH_CHECK(qw.is_contiguous(), "qw must be contiguous");
     auto result = at::empty(
         {Nc, Kc, block_k, block_n / 2}, qw.options().dtype(at::kByte));
     // Pack weight in [N,K] to [N/block_n, K/block_k, block_k, block_n]
@@ -228,7 +229,8 @@ at::Tensor qlinear_woq_pack(
     // Pack weight in [N,K] to [N/block_n, K/block_k, block_k, block_n]
     int8_t* src_data = (int8_t*)qw.data_ptr();
     int8_t* dst_data = (int8_t*)result.data_ptr();
-    auto psrc = GetVLAPtr<int8_t>(src_data, {block_n, Kc, block_k});
+    auto real_Kc = strides[0] / block_k;
+    auto psrc = GetVLAPtr<int8_t>(src_data, {block_n, real_Kc, block_k});
     auto pdst = GetVLAPtr<int8_t>(dst_data, {Kc, block_k, block_n});
     auto pack_loop =
         ThreadedLoop<3>({{Nc}, {Kc}, {0, block_n, N_GROUP_SIZE, false}}, "ABc");
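The stride-based real_Kc above is what lets the INT8 path accept a non-contiguous weight: when a tensor-parallel rank holds a slice of a larger [N, K] weight, the slice is a view whose row stride is still the parent's K, so indexing the source with Kc = K / block_k would walk the wrong rows. A minimal PyTorch sketch of the arithmetic (shapes are hypothetical, not taken from the kernel):

import torch

# Hypothetical shapes for illustration only.
N, K_full, K, block_k = 64, 256, 128, 32
full_qw = torch.randint(-128, 127, (N, K_full), dtype=torch.int8)

qw = full_qw[:, :K]                # TP shard as a view: shape [N, K], row stride K_full
print(qw.is_contiguous())          # False
print(qw.stride(0))                # 256 (K_full), not 128 (K)

Kc = K // block_k                  # logical K-blocks per row of the shard: 4
real_Kc = qw.stride(0) // block_k  # K-blocks per row in the parent layout: 8
print(Kc, real_Kc)                 # 4 8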

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -421,6 +421,8 @@ def get_checkpoint_files(model_name_or_path):
     kv_cache_dtype = torch.float8_e5m2
     config.kv_cache_dtype = kv_cache_dtype
 
+config.use_cache = True  # For inference, it should always be True
+
 # For DeepSeek models
 if not args.ipex_weight_only_quantization and args.ipex and args.dtype == "bfloat16":
     config.use_fused_moe = True
@@ -435,8 +437,6 @@ def get_checkpoint_files(model_name_or_path):
     config.max_seq_len = int(args.input_tokens) + int(args.max_new_tokens)
 if model_type == "whisper":
     config.text_max_length = config.max_source_positions + config.max_target_positions
-if model_type == "llava":
-    config.use_cache = True
 if model_type == "jamba":
     config.use_mamba_kernels = False
 if not hasattr(config, "lm_head_generation"):

examples/cpu/llm/inference/single_instance/run_quantization.py

Lines changed: 2 additions & 0 deletions
@@ -362,6 +362,8 @@ def str_to_kwargs(s):
     args.config_file, torchscript=True, trust_remote_code=True
 )
 
+config.use_cache = True  # For inference, it should always be True
+
 # For DeepSeek models
 if args.ipex_weight_only_quantization and args.weight_dtype == "INT8":
     config.use_fused_moe = True
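Both hunks above apply the same BKC from #3657: force use_cache = True on the config before the model is instantiated, instead of special-casing llava. A minimal sketch of the pattern outside these scripts (the model id and prompt are placeholders):

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder model for illustration
config = AutoConfig.from_pretrained(model_id)
config.use_cache = True  # For inference, it should always be True

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
model.eval()

inputs = tokenizer("Hello", return_tensors="pt")
with torch.inference_mode():
    # With use_cache=True, generate() reuses past key/values instead of
    # recomputing attention over the whole prefix at every decoding step.
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0]))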

intel_extension_for_pytorch/llm/utils.py

Lines changed: 201 additions & 17 deletions
@@ -911,14 +911,12 @@ def shard_low_precision_checkpoint(
 
     """
     assert tp_grain_size % 8 == 0, "tp_grain_size must be a multiple of 8"
-    if isinstance(model_config, dict):
-        num_heads = model_config["num_attention_heads"]
-        if "num_key_value_heads" in model_config:
-            num_heads = model_config["num_key_value_heads"]
-    else:
-        num_heads = model_config.num_attention_heads
-        if "num_key_value_heads" in model_config:
-            num_heads = model_config.num_key_value_heads
+    if not isinstance(model_config, dict):
+        model_config = model_config.to_dict()
+    num_heads = model_config["num_attention_heads"]
+    num_kv_heads = num_heads
+    if "num_key_value_heads" in model_config:
+        num_kv_heads = model_config["num_key_value_heads"]
     local_rank = rank
 
     mha_layers_split_by_N = [
@@ -928,6 +926,9 @@ def shard_low_precision_checkpoint(
         "q_b_proj",
         "kv_b_proj",
     ]
+    qkv_proj_layers = [
+        "qkv_proj",
+    ]
     # mlp is split with grain size = tp_grain_size
     mlp_layers_split_by_N = [
         "gate_proj",
@@ -938,6 +939,9 @@ def shard_low_precision_checkpoint(
         "w1",
         "w3",
     ]
+    gate_up_proj_layers = [
+        "gate_up_proj",
+    ]
     mha_layers_split_by_K = [
         "o_proj",
         "out_proj",
@@ -952,20 +956,28 @@ def shard_low_precision_checkpoint(
         "w2",
     ]
     lm_head_layers = ["lm_head"]  # split by K but not quantized
+
+    def _key_belongs_to(key, layer_group):
+        key_split = key.split(".")
+        for layer in layer_group:
+            if layer in key_split:
+                return True
+        return False
+
     low_precision_checkpoint_dict = low_precision_checkpoint.copy()
     head_range = [0]
-    head_per_rank = num_heads // world_size
+    head_per_rank = num_kv_heads // world_size
     for i in range(0, world_size):
         head_this_rank = head_per_rank
-        if i < num_heads % world_size:
+        if i < num_kv_heads % world_size:
             head_this_rank += 1
         head_range.append(head_range[-1] + head_this_rank)
     for key in low_precision_checkpoint.keys():
         q_head_start = head_range[rank]
         q_head_end = q_head_start + (head_range[rank + 1] - head_range[rank])
         if "bias" in key:
             continue
-        if any(substring in key for substring in mha_layers_split_by_N):
+        if _key_belongs_to(key, mha_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
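The new _key_belongs_to helper matches whole dot-separated key components instead of raw substrings, and the head range is now built from num_kv_heads so GQA checkpoints are sharded by key/value heads. Component matching matters once fused projections get their own branches: with substring matching, "v_proj" is a substring of "qkv_proj", so a fused QKV tensor would fall into the per-head MHA branch. A small standalone check (the checkpoint key is hypothetical):

def _key_belongs_to(key, layer_group):
    key_split = key.split(".")
    for layer in layer_group:
        if layer in key_split:
            return True
    return False

mha_layers_split_by_N = ["q_proj", "k_proj", "v_proj"]
fused_key = "model.layers.0.self_attn.qkv_proj.qweight"

# Substring matching misroutes the fused layer ("v_proj" is inside "qkv_proj"):
print(any(s in fused_key for s in mha_layers_split_by_N))  # True
# Component matching does not:
print(_key_belongs_to(fused_key, mha_layers_split_by_N))   # False
print(_key_belongs_to(fused_key, ["qkv_proj"]))            # True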
@@ -1046,7 +1058,91 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_N):
+        elif _key_belongs_to(key, qkv_proj_layers):
+            # need to split q, k and v proj then shard them separately
+            # finally concat them together
+            # mha layer split by N
+            data = low_precision_checkpoint_dict[key]
+            hidden_size = model_config["hidden_size"]
+            head_dim = hidden_size // num_heads
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                N_pack_factor = 1 if "scales" in key else 8
+                N = data.shape[-1] * N_pack_factor
+                q_pos = N - 2 * num_kv_heads * head_dim
+                k_pos = q_pos + num_kv_heads * head_dim
+                v_pos = k_pos + num_kv_heads * head_dim
+                q_pos //= N_pack_factor
+                k_pos //= N_pack_factor
+                v_pos //= N_pack_factor
+                data_list = [
+                    data[:, :q_pos],
+                    data[:, q_pos:k_pos],
+                    data[:, k_pos:v_pos],
+                ]
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = []
+                if "g_idx" not in key:
+                    N_pack_factor = 8 if "qzeros" in key else 1
+                    N = data.shape[-1] * N_pack_factor
+                    q_pos = N - 2 * num_kv_heads * head_dim
+                    k_pos = q_pos + num_kv_heads * head_dim
+                    v_pos = k_pos + num_kv_heads * head_dim
+                    q_pos //= N_pack_factor
+                    k_pos //= N_pack_factor
+                    v_pos //= N_pack_factor
+                    data_list = [
+                        data[:, :q_pos],
+                        data[:, q_pos:k_pos],
+                        data[:, k_pos:v_pos],
+                    ]
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mlp_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
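The fused-QKV branch assumes the checkpoint lays the projection out as [Q | K | V] along N, with K and V each spanning num_kv_heads * head_dim columns and Q taking the rest; the boundaries are then rescaled by the packing factor of the particular tensor, and each slice is sharded by KV-head range (or evenly by world_size when its width is not head-divisible) before being re-concatenated. A worked example of the position arithmetic with hypothetical GQA sizes:

# Hypothetical GQA configuration for illustration.
hidden_size, num_heads, num_kv_heads = 4096, 32, 8
head_dim = hidden_size // num_heads                      # 128

# AWQ qweight/qzeros pack 8 INT4 values per int32 along N; scales are unpacked.
N_pack_factor = 8
N = num_heads * head_dim + 2 * num_kv_heads * head_dim   # 4096 + 2048 = 6144

q_pos = N - 2 * num_kv_heads * head_dim                  # 4096: end of Q columns
k_pos = q_pos + num_kv_heads * head_dim                  # 5120: end of K columns
v_pos = k_pos + num_kv_heads * head_dim                  # 6144: end of V columns

# Convert logical column positions to packed-tensor columns.
q_pos //= N_pack_factor                                  # 512
k_pos //= N_pack_factor                                  # 640
v_pos //= N_pack_factor                                  # 768
print(q_pos, k_pos, v_pos)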
@@ -1183,7 +1279,95 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mha_layers_split_by_K):
+        elif _key_belongs_to(key, gate_up_proj_layers):
+            # need to split gate and up proj then shard them separately
+            # finally concat them together
+            # mlp layer split by N
+            data = low_precision_checkpoint_dict[key]
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if "scales" in key:
+                        assert (
+                            data.shape[1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // tp_grain_size
+                        dim = tp_grain_size
+                    else:
+                        assert (
+                            data.shape[1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if "qzeros" in key:
+                        assert (
+                            data.shape[-1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    elif "g_idx" not in key:  # qweight, scales
+                        assert (
+                            data.shape[-1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // tp_grain_size
+                        dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mha_layers_split_by_K):
+            if "bias" in key:
+                continue
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
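The fused gate_up_proj branch chunks the tensor into its gate and up halves and shards each half along N in grains of tp_grain_size columns, handing any remainder grains to the lowest ranks. The slice bounds reduce to the arithmetic below (sizes are illustrative only):

# Hypothetical sizes for illustration.
N_half = 11008            # columns of one half (gate or up) of the fused weight
tp_grain_size = 64
world_size = 6

grains = N_half // tp_grain_size        # 172 grains of 64 columns each
grains_per_rank = grains // world_size  # 28
grains_rem = grains % world_size        # 4 extra grains go to ranks 0..3

for local_rank in range(world_size):
    grains_start = grains_per_rank * local_rank + min(local_rank, grains_rem)
    grains_end = grains_start + grains_per_rank + (1 if local_rank < grains_rem else 0)
    cols = (grains_end - grains_start) * tp_grain_size
    print(local_rank, grains_start, grains_end, cols)
# Ranks 0-3 get 29 grains (1856 columns); ranks 4-5 get 28 grains (1792 columns).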
@@ -1271,10 +1455,10 @@ def shard_low_precision_checkpoint(
                     q_head_end = local_rank + 1
                 low_precision_checkpoint_dict[key] = data[
                     :, q_head_start * dim : q_head_end * dim
-                ].contiguous()
+                ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_K):
+        elif _key_belongs_to(key, mlp_layers_split_by_K):
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1424,10 +1608,10 @@ def shard_low_precision_checkpoint(
                 )
                 low_precision_checkpoint_dict[key] = data[
                     :, grains_start * dim : grains_end * dim
-                ].contiguous()
+                ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in lm_head_layers):
+        elif _key_belongs_to(key, lm_head_layers):
             # lm_head: [N, K] (not quantized)
             # Same for all quantization methods
             data = low_precision_checkpoint_dict[key]
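The two hunks above drop the trailing .contiguous() on the K-split shards, so each rank keeps a view into the original checkpoint tensor rather than materializing a private copy; this is in line with the memory reduction for DeepSeek R1 DA8W8 with TP noted in the commit message. A minimal illustration of the difference (shapes are arbitrary):

import torch

full = torch.empty(1024, 8192, dtype=torch.int8)

view_shard = full[:, :1024]                # no new allocation, shares storage
copy_shard = full[:, :1024].contiguous()   # allocates and copies a fresh buffer

print(view_shard.data_ptr() == full.data_ptr())  # True: same underlying buffer
print(copy_shard.data_ptr() == full.data_ptr())  # False: separate allocation
print(copy_shard.untyped_storage().nbytes())     # 1048576 (new 1 MiB buffer)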

intel_extension_for_pytorch/transformers/models/cpu/modules/decoder.py

Lines changed: 0 additions & 4 deletions
@@ -414,10 +414,6 @@ def __init__(self, module, config, tpp=False, woq=False):
                 if len(w2_shared_compensation_list) > 0
                 else None
             )
-
-            print(
-                "[INFO] Using fused shared MOE WOQ INT8 lowbit weights path..."
-            )
         else:
             if (
                 (self.use_fused_moe or self.use_fused_moe_woq)
