File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -89,15 +89,14 @@ def fused_experts_with_mc2(
89
89
0 :5 ]
90
90
91
91
w1 = w1 .transpose (1 , 2 )
92
- expert_token_nums = torch .cumsum (expert_token_nums ,
93
- dim = 0 ,
94
- dtype = torch .int64 )
92
+
95
93
group_list = expert_token_nums .to (torch .int64 )
96
94
gate_up_out_list = torch_npu .npu_grouped_matmul (
97
95
x = [expand_x ],
98
96
weight = [w1 ],
99
97
split_item = 2 ,
100
- group_list_type = 0 ,
98
+ # 1 means count mode, to avoid cumulative operation of the group list
99
+ group_list_type = 1 ,
101
100
group_type = 0 ,
102
101
group_list = group_list ,
103
102
)
@@ -111,7 +110,7 @@ def fused_experts_with_mc2(
111
110
x = [gate_up_out ],
112
111
weight = [w2 ],
113
112
split_item = 2 ,
114
- group_list_type = 0 ,
113
+ group_list_type = 1 ,
115
114
group_type = 0 ,
116
115
group_list = group_list ,
117
116
)
You can’t perform that action at this time.
0 commit comments