36
36
37
37
VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
38
38
USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
39
+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
39
40
40
41
41
42
def fused_experts_with_mc2 (
@@ -50,6 +51,14 @@ def fused_experts_with_mc2(
50
51
) -> torch .Tensor :
51
52
global_bs = 0
52
53
moe_expert_num = len (expert_map )
54
+
55
+ rank = torch .distributed .get_rank ()
56
+ if VLLM_ENABLE_FIX_ROUTE :
57
+ step = hidden_states .shape [0 ] * top_k
58
+ uniform_topk_list = [(i + rank ) % moe_expert_num
59
+ for i in range (rank * step , (rank + 1 ) * step )]
60
+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (
61
+ hidden_states .shape [0 ], - 1 ).to (hidden_states .device )
53
62
kwargs = {
54
63
"x" : hidden_states ,
55
64
"expert_ids" : topk_ids ,
@@ -59,8 +68,6 @@ def fused_experts_with_mc2(
59
68
"global_bs" : global_bs ,
60
69
}
61
70
62
- rank = torch .distributed .get_rank ()
63
-
64
71
quant_mode = 0
65
72
ep_group = get_ep_group ().device_group
66
73
local_rank = torch .distributed .get_rank (group = ep_group )
@@ -89,15 +96,20 @@ def fused_experts_with_mc2(
89
96
0 :5 ]
90
97
91
98
w1 = w1 .transpose (1 , 2 )
92
- expert_token_nums = torch .cumsum (expert_token_nums ,
93
- dim = 0 ,
94
- dtype = torch .int64 )
95
- group_list = expert_token_nums .to (torch .int64 )
99
+
100
+ if VLLM_ENABLE_FIX_ROUTE :
101
+ uniform_group_list = hidden_states .shape [0 ] * \
102
+ all_to_all_group_size * top_k // moe_expert_num
103
+ group_list = torch .Tensor ([uniform_group_list ] *
104
+ w1 .shape [0 ]).long ().to (hidden_states .device )
105
+ else :
106
+ group_list = expert_token_nums
96
107
gate_up_out_list = torch_npu .npu_grouped_matmul (
97
108
x = [expand_x ],
98
109
weight = [w1 ],
99
110
split_item = 2 ,
100
- group_list_type = 0 ,
111
+ # 1 = count mode: group_list holds per-group token counts, so no cumsum is needed
112
+ group_list_type = 1 ,
101
113
group_type = 0 ,
102
114
group_list = group_list ,
103
115
)
@@ -111,7 +123,7 @@ def fused_experts_with_mc2(
111
123
x = [gate_up_out ],
112
124
weight = [w2 ],
113
125
split_item = 2 ,
114
- group_list_type = 0 ,
126
+ group_list_type = 1 ,
115
127
group_type = 0 ,
116
128
group_list = group_list ,
117
129
)
0 commit comments