
Commit a0535ff

CISC, ggerganov, 0cc4m, qnixsynapse, and jeffbolznv authored
ggml : implement REGLU/GEGLU/SWIGLU ops (#14158)
* implement unary REGLU/GEGLU/SWIGLU cpu ops
* relax constraints
* duplicate shape of source
* fix ggml_vec_geglu_f16
* special case gated ops
* implement unary REGLU/GEGLU/SWIGLU cuda ops
* tighten constraints again
* refactor into GGML_GLU_OP
* metal : add glu kernels ggml-ci
* add CUDA_GLU_BLOCK_SIZE [no ci]
* more constraints and use 64bit ints ggml-ci
* 64bit multiplication [no ci]
* implement swapped variants (cpu/cuda)
* update comment [no ci] ggml-ci
* Vulkan: Add GLU ops and shaders
* SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
* ggml : implement GLU for split up/gate (#14181)
  * implement GLU for split up/gate
  * add tests for ggml_glu_split
  * Vulkan: Implement glu_split logic and shader support
  * add split to logging [no ci]
  * SYCL: refactor element_size ops and add split up and gate support to gated kernels
  * SYCL: switch GEGLU to use tanh approximation
  Co-authored-by: 0cc4m <[email protected]>
  Co-authored-by: Akarshan <[email protected]>
* GGML: increase OP count in assertion
* Refactor: Optimize SYCL element-wise operations with unary function inlining
  This commit refactors the SYCL element-wise operations to improve performance by:
  - Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead.
  - Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic.
  - Replacing direct kernel calls with calls to these inlined functions.
  - Using `__dpct_inline__` to encourage compiler inlining.
  - Minor code cleanup and consistency improvements.
  The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices.
* vulkan: Increase workgroup size for GLU, for performance (#14345)
  * vulkan: Increase workgroup size for GLU, for performance
  * vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup
* merge fix
* metal : add support for split and swap ggml-ci

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>
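All three ops follow the same gated-linear-unit pattern: an activation applied to one half of the input, multiplied elementwise by the other (gate) half. As a purely illustrative scalar sketch of the elementwise math (these are not the actual ggml kernels, which operate on whole rows in F32/F16; GEGLU here uses the tanh approximation mentioned in the SYCL change above):

#include <math.h>

// Illustrative scalar forms of the three gated activations (not the ggml kernels):
// each applies an activation to x and multiplies by the gate value g.
static float reglu_sketch (float x, float g) { return (x > 0.0f ? x : 0.0f) * g; }  // ReLU(x) * g
static float swiglu_sketch(float x, float g) { return x / (1.0f + expf(-x)) * g; }  // SiLU(x) * g
static float geglu_sketch (float x, float g) {                                      // GELU(x) * g, tanh approximation
    const float c = 0.79788456080286535588f; // sqrt(2/pi)
    return 0.5f * x * (1.0f + tanhf(c * (x + 0.044715f * x * x * x))) * g;
}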
1 parent bd9c981 commit a0535ff

26 files changed: +2044 -1071 lines changed

ggml/include/ggml.h

Lines changed: 69 additions & 0 deletions
@@ -520,6 +520,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
 
+        GGML_OP_GLU,
+
         GGML_OP_COUNT,
     };
 
@@ -543,6 +545,14 @@ extern "C" {
         GGML_UNARY_OP_COUNT,
     };
 
+    enum ggml_glu_op {
+        GGML_GLU_OP_REGLU,
+        GGML_GLU_OP_GEGLU,
+        GGML_GLU_OP_SWIGLU,
+
+        GGML_GLU_OP_COUNT,
+    };
+
     enum ggml_object_type {
         GGML_OBJECT_TYPE_TENSOR,
         GGML_OBJECT_TYPE_GRAPH,
@@ -658,6 +668,7 @@ extern "C" {
     GGML_API const char * ggml_op_symbol(enum ggml_op op);
 
     GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
     GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
     GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -762,6 +773,7 @@ extern "C" {
     GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
     GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+    GGML_API enum ggml_glu_op   ggml_get_glu_op(const struct ggml_tensor * tensor);
 
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1090,6 +1102,63 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    // gated linear unit ops
+    // A: n columns, r rows,
+    // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
+    GGML_API struct ggml_tensor * ggml_glu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            enum ggml_glu_op op,
+            bool swapped);
+
+    GGML_API struct ggml_tensor * ggml_reglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    // A: n columns, r rows,
+    // B: n columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            enum ggml_glu_op op);
+
+    GGML_API struct ggml_tensor * ggml_reglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
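The fused entry points take a single tensor whose rows hold the value and gate halves side by side, while the *_split variants added in #14181 take them as two separate tensors of identical shape. A hypothetical graph-building fragment (the function and tensor names below are illustrative and not part of the commit; it assumes fused is [2*n_ff, n_tokens] with the gate in the second half of each row, and up/gate are both [n_ff, n_tokens]):

#include "ggml.h"

// Illustrative use of the new GLU API.
static struct ggml_tensor * swiglu_examples(
        struct ggml_context * ctx,
        struct ggml_tensor  * fused,
        struct ggml_tensor  * up,
        struct ggml_tensor  * gate) {
    // Fused layout: the result has half as many columns as the input.
    struct ggml_tensor * out_fused   = ggml_swiglu(ctx, fused);
    // The same result through the generic entry point.
    struct ggml_tensor * out_generic = ggml_glu(ctx, fused, GGML_GLU_OP_SWIGLU, /*swapped =*/ false);
    // Split layout: the activation is applied to the first tensor and the result
    // is multiplied elementwise by the second.
    struct ggml_tensor * out_split   = ggml_swiglu_split(ctx, gate, up);
    (void) out_generic; (void) out_split;
    return out_fused;
}

The swapped variants cover the case where the gate is stored in the first half of each row instead of the second.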

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 16 additions & 0 deletions
@@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     GGML_ABORT("fatal error");
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
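The forward kernels themselves (ggml_compute_forward_glu and the ggml_vec_* helpers) are not part of this excerpt. The sketch below is only an illustration of the per-row work in the fused case, with a made-up helper name and signature: each row of 2*nc values is split into an activation half and a gate half, and the activated half is multiplied by the gate.

// Illustrative per-row sketch of a fused GLU op (not the actual ggml kernel).
// With swapped == true the two halves of the row trade roles.
static void glu_row_sketch(int nc, float * dst, const float * src,
                           int swapped, float (*act)(float)) {
    const float * x = swapped ? src + nc : src;      // values the activation is applied to
    const float * g = swapped ? src      : src + nc; // gate values
    for (int i = 0; i < nc; ++i) {
        dst[i] = act(x[i]) * g[i];
    }
}

The n_tasks = n_threads choice in ggml_get_n_tasks above means the rows of the tensor are divided across all worker threads, each running this kind of loop over its own slice.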
