Skip to content

Commit 343b6e9

Browse files
authored
CANN: update aclnnGroupedMatmulV2 to aclnnGroupedMatmulV3 (#14411)
* [CANN]update to aclnnGroupedMatmulV2 Signed-off-by: noemotiovon <[email protected]> * Support MUL_MAT_ID on 310p Signed-off-by: noemotiovon <[email protected]> * fix editorconfig Signed-off-by: noemotiovon <[email protected]> --------- Signed-off-by: noemotiovon <[email protected]>
1 parent 6a746cf commit 343b6e9

File tree

1 file changed

+65
-4
lines changed

1 file changed

+65
-4
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
#include <aclnnop/aclnn_eq_tensor.h>
6666
#include <aclnnop/aclnn_gt_scalar.h>
6767
#include <aclnnop/aclnn_pow.h>
68-
#include <aclnnop/aclnn_grouped_matmul_v2.h>
68+
#include <aclnnop/aclnn_grouped_matmul_v3.h>
6969
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
7070
#include <float.h>
7171

@@ -2654,6 +2654,67 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
26542654
memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
26552655
}
26562656

2657+
#ifdef ASCEND_310P
2658+
ggml_tensor src0_row = *src0;
2659+
ggml_tensor src1_row = *src1;
2660+
ggml_tensor dst_row = *dst;
2661+
2662+
if (src0->type == GGML_TYPE_F16) {
2663+
src0_row.type = GGML_TYPE_F32;
2664+
}
2665+
2666+
// src0_row [D, M, 1, 1] weight without permute
2667+
src0_row.ne[2] = 1;
2668+
src0_row.ne[3] = 1;
2669+
src0_row.nb[0] = ori_src0_nb[0];
2670+
src0_row.nb[1] = ori_src0_nb[1];
2671+
src0_row.nb[2] = ori_src0_nb[1];
2672+
src0_row.nb[3] = ori_src0_nb[1];
2673+
2674+
// src1_row [D, 1, 1, 1] -> input
2675+
src1_row.ne[1] = 1;
2676+
src1_row.ne[2] = 1;
2677+
src1_row.ne[3] = 1;
2678+
src1_row.nb[2] = nb11;
2679+
src1_row.nb[3] = nb11;
2680+
2681+
// dst_row [M, 1, 1, 1] -> out
2682+
dst_row.ne[1] = 1;
2683+
dst_row.ne[2] = 1;
2684+
dst_row.ne[3] = 1;
2685+
dst_row.nb[2] = nb1;
2686+
dst_row.nb[3] = nb1;
2687+
2688+
//create weight for one row
2689+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2690+
for (int64_t id = 0; id < n_ids; id++) {
2691+
// expert index
2692+
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2693+
GGML_ASSERT(i02 >= 0 && i02 < n_as);
2694+
2695+
// If B = 1 (broadcast), always use 0; otherwise, use id.
2696+
int64_t i11 = (ne11 == 1 ? 0 : id);
2697+
int64_t i12 = iid1;
2698+
2699+
int64_t i1 = id;
2700+
int64_t i2 = i12;
2701+
2702+
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
2703+
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2704+
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2705+
2706+
src0_row.data = src0_tmp_ptr;
2707+
src1_row.data = src1_tmp_ptr;
2708+
dst_row.data = dst_tmp_ptr;
2709+
dst_row.src[0] = &src0_row;
2710+
dst_row.src[1] = &src1_row;
2711+
2712+
ggml_cann_mul_mat(ctx, &dst_row);
2713+
}
2714+
}
2715+
return;
2716+
#endif
2717+
26572718
std::vector<aclTensor*> src0_tensor_vec;
26582719
std::vector<aclTensor*> src1_tensor_vec;
26592720
std::vector<aclTensor*> dst_tensor_vec;
@@ -2701,9 +2762,9 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27012762
}
27022763

27032764
size_t GROUP_SIZE = 128;
2704-
// GroupedMatmulV2 required tensor_list.size < 128
2765+
// GroupedMatmulV3 required tensor_list.size < 128
27052766
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
2706-
// split and call GroupedMatmulV2
2767+
// split and call GroupedMatmulV3
27072768
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
27082769
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
27092770
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
@@ -2713,7 +2774,7 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27132774
aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
27142775
aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
27152776

2716-
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
2777+
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list,
27172778
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
27182779

27192780
ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);

0 commit comments

Comments
 (0)