Skip to content

Commit eff5e45

Browse files
authored
add GELU_ERF (#14455)
1 parent a6a4795 commit eff5e45

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ struct vk_device_struct {
431431

432432
// [src/dst 0=fp32,1=fp16]
433433
vk_pipeline pipeline_gelu[2];
434+
vk_pipeline pipeline_gelu_erf[2];
434435
vk_pipeline pipeline_gelu_quick[2];
435436
vk_pipeline pipeline_silu[2];
436437
vk_pipeline pipeline_relu[2];
@@ -2761,6 +2762,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27612762
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
27622763

27632764
CREATE_UNARY(gelu)
2765+
CREATE_UNARY(gelu_erf)
27642766
CREATE_UNARY(gelu_quick)
27652767
CREATE_UNARY(silu)
27662768
CREATE_UNARY(relu)
@@ -6481,6 +6483,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64816483
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
64826484
case GGML_UNARY_OP_GELU:
64836485
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6486+
case GGML_UNARY_OP_GELU_ERF:
6487+
return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
64846488
case GGML_UNARY_OP_GELU_QUICK:
64856489
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
64866490
case GGML_UNARY_OP_RELU:
@@ -8827,6 +8831,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
88278831
switch (ggml_get_unary_op(node)) {
88288832
case GGML_UNARY_OP_SILU:
88298833
case GGML_UNARY_OP_GELU:
8834+
case GGML_UNARY_OP_GELU_ERF:
88308835
case GGML_UNARY_OP_GELU_QUICK:
88318836
case GGML_UNARY_OP_RELU:
88328837
case GGML_UNARY_OP_TANH:
@@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90729077
switch (ggml_get_unary_op(node)) {
90739078
case GGML_UNARY_OP_SILU:
90749079
case GGML_UNARY_OP_GELU:
9080+
case GGML_UNARY_OP_GELU_ERF:
90759081
case GGML_UNARY_OP_GELU_QUICK:
90769082
case GGML_UNARY_OP_RELU:
90779083
case GGML_UNARY_OP_TANH:
@@ -9289,6 +9295,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
92899295
switch (ggml_get_unary_op(tensor)) {
92909296
case GGML_UNARY_OP_SILU:
92919297
case GGML_UNARY_OP_GELU:
9298+
case GGML_UNARY_OP_GELU_ERF:
92929299
case GGML_UNARY_OP_GELU_QUICK:
92939300
case GGML_UNARY_OP_RELU:
92949301
case GGML_UNARY_OP_TANH:
@@ -10095,6 +10102,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1009510102
case GGML_OP_UNARY:
1009610103
switch (ggml_get_unary_op(op)) {
1009710104
case GGML_UNARY_OP_GELU:
10105+
case GGML_UNARY_OP_GELU_ERF:
1009810106
case GGML_UNARY_OP_GELU_QUICK:
1009910107
case GGML_UNARY_OP_SILU:
1010010108
case GGML_UNARY_OP_RELU:
@@ -10835,6 +10843,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1083510843
case GGML_UNARY_OP_GELU:
1083610844
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
1083710845
break;
10846+
case GGML_UNARY_OP_GELU_ERF:
10847+
tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
10848+
break;
1083810849
case GGML_UNARY_OP_GELU_QUICK:
1083910850
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
1084010851
break;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
15+
// ref: https://www.johndcook.com/blog/python_erf/
16+
const float p_erf = 0.3275911f;
17+
const float a1_erf = 0.254829592f;
18+
const float a2_erf = -0.284496736f;
19+
const float a3_erf = 1.421413741f;
20+
const float a4_erf = -1.453152027f;
21+
const float a5_erf = 1.061405429f;
22+
23+
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
24+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
25+
26+
if (i >= p.KX) {
27+
return;
28+
}
29+
30+
const float a = float(data_a[i]);
31+
const float a_div_sqr2 = a * SQRT_2_INV;
32+
const float sign_x = sign(a_div_sqr2);
33+
const float x = abs(a_div_sqr2);
34+
const float t = 1.0f / (1.0f + p_erf * x);
35+
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
36+
const float erf_approx = sign_x * y;
37+
38+
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
39+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ void process_shaders() {
574574

575575
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
576576
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
577+
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
578+
string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
577579
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
578580
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
579581
string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

0 commit comments

Comments
 (0)