@@ -431,6 +431,7 @@ struct vk_device_struct {
431
431
432
432
// [src/dst 0=fp32,1=fp16]
433
433
vk_pipeline pipeline_gelu[2];
434
+ vk_pipeline pipeline_gelu_erf[2];
434
435
vk_pipeline pipeline_gelu_quick[2];
435
436
vk_pipeline pipeline_silu[2];
436
437
vk_pipeline pipeline_relu[2];
@@ -2761,6 +2762,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2761
2762
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2762
2763
2763
2764
CREATE_UNARY(gelu)
2765
+ CREATE_UNARY(gelu_erf)
2764
2766
CREATE_UNARY(gelu_quick)
2765
2767
CREATE_UNARY(silu)
2766
2768
CREATE_UNARY(relu)
@@ -6481,6 +6483,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6481
6483
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
6482
6484
case GGML_UNARY_OP_GELU:
6483
6485
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6486
+ case GGML_UNARY_OP_GELU_ERF:
6487
+ return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
6484
6488
case GGML_UNARY_OP_GELU_QUICK:
6485
6489
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
6486
6490
case GGML_UNARY_OP_RELU:
@@ -8827,6 +8831,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
8827
8831
switch (ggml_get_unary_op(node)) {
8828
8832
case GGML_UNARY_OP_SILU:
8829
8833
case GGML_UNARY_OP_GELU:
8834
+ case GGML_UNARY_OP_GELU_ERF:
8830
8835
case GGML_UNARY_OP_GELU_QUICK:
8831
8836
case GGML_UNARY_OP_RELU:
8832
8837
case GGML_UNARY_OP_TANH:
@@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9072
9077
switch (ggml_get_unary_op(node)) {
9073
9078
case GGML_UNARY_OP_SILU:
9074
9079
case GGML_UNARY_OP_GELU:
9080
+ case GGML_UNARY_OP_GELU_ERF:
9075
9081
case GGML_UNARY_OP_GELU_QUICK:
9076
9082
case GGML_UNARY_OP_RELU:
9077
9083
case GGML_UNARY_OP_TANH:
@@ -9289,6 +9295,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9289
9295
switch (ggml_get_unary_op(tensor)) {
9290
9296
case GGML_UNARY_OP_SILU:
9291
9297
case GGML_UNARY_OP_GELU:
9298
+ case GGML_UNARY_OP_GELU_ERF:
9292
9299
case GGML_UNARY_OP_GELU_QUICK:
9293
9300
case GGML_UNARY_OP_RELU:
9294
9301
case GGML_UNARY_OP_TANH:
@@ -10095,6 +10102,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10095
10102
case GGML_OP_UNARY:
10096
10103
switch (ggml_get_unary_op(op)) {
10097
10104
case GGML_UNARY_OP_GELU:
10105
+ case GGML_UNARY_OP_GELU_ERF:
10098
10106
case GGML_UNARY_OP_GELU_QUICK:
10099
10107
case GGML_UNARY_OP_SILU:
10100
10108
case GGML_UNARY_OP_RELU:
@@ -10835,6 +10843,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10835
10843
case GGML_UNARY_OP_GELU:
10836
10844
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
10837
10845
break;
10846
+ case GGML_UNARY_OP_GELU_ERF:
10847
+ tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
10848
+ break;
10838
10849
case GGML_UNARY_OP_GELU_QUICK:
10839
10850
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
10840
10851
break;
0 commit comments