
Commit eb729ff

洪炜杰 (hahazhky) authored and committed
add fix routing for performance test
Signed-off-by: zhky <[email protected]>
1 parent 3442fbd commit eb729ff

File tree

2 files changed: 23 additions & 8 deletions


vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@
     lambda: os.getenv("C_COMPILER", None),
     "VLLM_VERSION":
     lambda: os.getenv("VLLM_VERSION", None),
+    # dispatch tokens evenly across experts for performance testing
+    "VLLM_ENABLE_FIX_ROUTE":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_FIX_ROUTE", '0'))),
 }
 
 # end-env-vars-definition
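
The new switch follows the same pattern as the other boolean variables in this file: unset or "0" means disabled, any other integer string means enabled. A minimal standalone sketch of the parsing rule (only the variable name comes from the diff; the rest is illustrative):

import os

# Same rule as the lambda registered above: the default '0' parses to False,
# any other integer string such as '1' parses to True.
os.environ["VLLM_ENABLE_FIX_ROUTE"] = "1"
enabled = bool(int(os.getenv("VLLM_ENABLE_FIX_ROUTE", '0')))
print(enabled)  # True

A performance run would therefore export VLLM_ENABLE_FIX_ROUTE=1 in the environment before launching.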

vllm_ascend/ops/fused_moe.py

Lines changed: 20 additions & 8 deletions
@@ -36,6 +36,7 @@
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
+VLLM_ENABLE_FIX_ROUTE: bool = envs_ascend.VLLM_ENABLE_FIX_ROUTE
 
 
 def fused_experts_with_mc2(
@@ -50,6 +51,14 @@ def fused_experts_with_mc2(
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
+
+    rank = torch.distributed.get_rank()
+    if VLLM_ENABLE_FIX_ROUTE:
+        step = hidden_states.shape[0] * top_k
+        uniform_topk_list = [(i + rank) % moe_expert_num
+                             for i in range(rank * step, (rank + 1) * step)]
+        topk_ids = torch.Tensor(uniform_topk_list).int().view(
+            hidden_states.shape[0], -1).to(hidden_states.device)
 kwargs = {
         "x": hidden_states,
         "expert_ids": topk_ids,
@@ -59,8 +68,6 @@ def fused_experts_with_mc2(
         "global_bs": global_bs,
     }
 
-    rank = torch.distributed.get_rank()
-
     quant_mode = 0
     ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
@@ -89,15 +96,20 @@ def fused_experts_with_mc2(
         0:5]
 
     w1 = w1.transpose(1, 2)
-    expert_token_nums = torch.cumsum(expert_token_nums,
-                                     dim=0,
-                                     dtype=torch.int64)
-    group_list = expert_token_nums.to(torch.int64)
+
+    if VLLM_ENABLE_FIX_ROUTE:
+        uniform_group_list = hidden_states.shape[0] * \
+            all_to_all_group_size * top_k // moe_expert_num
+        group_list = torch.Tensor([uniform_group_list] *
+                                  w1.shape[0]).long().to(hidden_states.device)
+    else:
+        group_list = expert_token_nums
 gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[expand_x],
         weight=[w1],
         split_item=2,
-        group_list_type=0,
+        # 1 means count mode, to avoid a cumulative operation on the group list
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
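
The removed lines turned per-expert token counts into cumulative offsets via torch.cumsum, matching group_list_type=0; the new code passes the counts directly and selects count mode with group_list_type=1, so the cumsum is no longer needed. A small sketch of how the two conventions relate (toy counts; the mode semantics are taken from the comment in the diff):

import torch

# group_list_type=1 takes per-expert token counts; group_list_type=0 takes the
# cumulative end offsets that the removed torch.cumsum used to produce.
expert_token_nums = torch.tensor([3, 1, 4, 0], dtype=torch.int64)
group_list_counts = expert_token_nums                        # type 1: [3, 1, 4, 0]
group_list_offsets = torch.cumsum(expert_token_nums, dim=0)  # type 0: [3, 4, 8, 8]
print(group_list_counts.tolist(), group_list_offsets.tolist())

Under fixed routing every expert receives the same number of tokens, which is why a single constant repeated w1.shape[0] times can replace the measured counts.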
@@ -111,7 +123,7 @@ def fused_experts_with_mc2(
         x=[gate_up_out],
         weight=[w2],
         split_item=2,
-        group_list_type=0,
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
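
Both grouped matmuls share the same group_list, so they have to switch to count mode together. As a cross-rank sanity check (same toy sizes as the earlier sketch, with world_size standing in for all_to_all_group_size), summing the fixed route over all ranks gives every expert exactly the constant used for uniform_group_list:

import torch

# Sum the fixed route over all ranks; every expert should receive exactly
# num_tokens * world_size * top_k // moe_expert_num tokens, the same constant
# used for uniform_group_list in the hunk above.
num_tokens, top_k, moe_expert_num, world_size = 4, 2, 8, 2
counts = torch.zeros(moe_expert_num, dtype=torch.int64)
for rank in range(world_size):
    step = num_tokens * top_k
    ids = [(i + rank) % moe_expert_num
           for i in range(rank * step, (rank + 1) * step)]
    counts += torch.bincount(torch.tensor(ids), minlength=moe_expert_num)
print(counts.tolist())  # [2, 2, 2, 2, 2, 2, 2, 2] == 4 * 2 * 2 // 8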
