44
44
45
45
VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
46
46
USING_LCCL_COM : bool = envs_ascend .USING_LCCL_COM
47
+ VLLM_ENABLE_FIX_ROUTE : bool = envs_ascend .VLLM_ENABLE_FIX_ROUTE
47
48
48
49
49
50
def fused_experts_with_mc2 (
@@ -58,6 +59,14 @@ def fused_experts_with_mc2(
58
59
) -> torch .Tensor :
59
60
global_bs = 0
60
61
moe_expert_num = len (expert_map )
62
+
63
+ rank = torch .distributed .get_rank ()
64
+ if VLLM_ENABLE_FIX_ROUTE :
65
+ step = hidden_states .shape [0 ] * top_k
66
+ uniform_topk_list = [(i + rank ) % moe_expert_num
67
+ for i in range (rank * step , (rank + 1 ) * step )]
68
+ topk_ids = torch .Tensor (uniform_topk_list ).int ().view (
69
+ hidden_states .shape [0 ], - 1 ).to (hidden_states .device )
61
70
kwargs = {
62
71
"x" : hidden_states ,
63
72
"expert_ids" : topk_ids ,
@@ -67,8 +76,6 @@ def fused_experts_with_mc2(
67
76
"global_bs" : global_bs ,
68
77
}
69
78
70
- rank = torch .distributed .get_rank ()
71
-
72
79
quant_mode = 0
73
80
ep_group = get_ep_group ().device_group
74
81
local_rank = torch .distributed .get_rank (group = ep_group )
@@ -97,15 +104,20 @@ def fused_experts_with_mc2(
97
104
0 :5 ]
98
105
99
106
w1 = w1 .transpose (1 , 2 )
100
- expert_token_nums = torch .cumsum (expert_token_nums ,
101
- dim = 0 ,
102
- dtype = torch .int64 )
103
- group_list = expert_token_nums .to (torch .int64 )
107
+
108
+ if VLLM_ENABLE_FIX_ROUTE :
109
+ uniform_group_list = hidden_states .shape [0 ] * \
110
+ all_to_all_group_size * top_k // moe_expert_num
111
+ group_list = torch .Tensor ([uniform_group_list ] *
112
+ w1 .shape [0 ]).long ().to (hidden_states .device )
113
+ else :
114
+ group_list = expert_token_nums
104
115
gate_up_out_list = torch_npu .npu_grouped_matmul (
105
116
x = [expand_x ],
106
117
weight = [w1 ],
107
118
split_item = 2 ,
108
- group_list_type = 0 ,
119
+ # group_list_type = 1 means count mode: group_list holds per-group token counts, so no cumulative sum of the group list is needed
120
+ group_list_type = 1 ,
109
121
group_type = 0 ,
110
122
group_list = group_list ,
111
123
)
@@ -119,7 +131,7 @@ def fused_experts_with_mc2(
119
131
x = [gate_up_out ],
120
132
weight = [w2 ],
121
133
split_item = 2 ,
122
- group_list_type = 0 ,
134
+ group_list_type = 1 ,
123
135
group_type = 0 ,
124
136
group_list = group_list ,
125
137
)
0 commit comments