Skip to content

Commit bd9c981

Browse files
jeffbolznvslaren
andauthored
vulkan: Add fusion support for RMS_NORM+MUL (#14366)
* vulkan: Add fusion support for RMS_NORM+MUL - Add a use_count to ggml_tensor, so we can detect if an output is used more than once. - Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor. - Add detection logic and basic fusion logic in ggml-vulkan. - Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test. * extract some common fusion logic * fix -Winconsistent-missing-override * move ggml_can_fuse to a common function * build fix * C and C++ versions of can_fuse * move use count to the graph to avoid data races and double increments when used in multiple threads * use hash table lookup to find node index * change use_counts to be indexed by hash table slot * minimize hash lookups style fixes * last node doesn't need single use. fix type. handle mul operands being swapped. * remove redundant parameter --------- Co-authored-by: slaren <[email protected]>
1 parent 27208bf commit bd9c981

File tree

8 files changed

+261
-54
lines changed

8 files changed

+261
-54
lines changed

ggml/include/ggml-backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ extern "C" {
339339
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
340340

341341
// Compare the output of two backends
342-
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
342+
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
343343

344344
// Tensor initialization
345345
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

ggml/src/ggml-backend.cpp

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
817817
}
818818
if (sched->debug > 1) {
819819
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
820-
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
821-
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
820+
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
821+
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
822+
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
822823
for (int j = 0; j < GGML_MAX_SRC; j++) {
823824
struct ggml_tensor * src = node->src[j];
824825
if (src == NULL) {
@@ -1826,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
18261827
ggml_free(copy.ctx_unallocated);
18271828
}
18281829

1829-
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
1830+
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
18301831
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
18311832
if (copy.buffer == NULL) {
18321833
return false;
@@ -1837,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
18371838

18381839
assert(g1->n_nodes == g2->n_nodes);
18391840

1840-
for (int i = 0; i < g1->n_nodes; i++) {
1841-
struct ggml_tensor * t1 = g1->nodes[i];
1842-
struct ggml_tensor * t2 = g2->nodes[i];
1841+
if (test_node != nullptr) {
1842+
// Compute the whole graph and only test the output for a specific tensor
1843+
ggml_backend_graph_compute(backend1, g1);
1844+
ggml_backend_graph_compute(backend2, g2);
18431845

1844-
assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
1846+
int test_node_idx = -1;
1847+
for (int i = 0; i < g1->n_nodes; i++) {
1848+
struct ggml_tensor * t1 = g1->nodes[i];
1849+
if (t1 == test_node) {
1850+
test_node_idx = i;
1851+
break;
1852+
}
1853+
}
1854+
GGML_ASSERT(test_node_idx != -1);
18451855

1846-
struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
1847-
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
1856+
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
1857+
} else {
1858+
for (int i = 0; i < g1->n_nodes; i++) {
1859+
struct ggml_tensor * t1 = g1->nodes[i];
1860+
struct ggml_tensor * t2 = g2->nodes[i];
18481861

1849-
ggml_backend_graph_compute(backend1, &g1v);
1850-
ggml_backend_graph_compute(backend2, &g2v);
1862+
assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
18511863

1852-
if (ggml_is_view_op(t1->op)) {
1853-
continue;
1854-
}
1864+
struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
1865+
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
18551866

1856-
// compare results, calculate rms etc
1857-
if (!callback(i, t1, t2, user_data)) {
1858-
break;
1867+
ggml_backend_graph_compute(backend1, &g1v);
1868+
ggml_backend_graph_compute(backend2, &g2v);
1869+
1870+
if (ggml_is_view_op(t1->op)) {
1871+
continue;
1872+
}
1873+
1874+
// compare results, calculate rms etc
1875+
if (!callback(i, t1, t2, user_data)) {
1876+
break;
1877+
}
18591878
}
18601879
}
1861-
18621880
ggml_backend_graph_copy_free(copy);
18631881

18641882
return true;

ggml/src/ggml-impl.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ struct ggml_cgraph {
301301
struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
302302
struct ggml_tensor ** grad_accs; // accumulators for node gradients
303303
struct ggml_tensor ** leafs; // tensors with constant data
304+
int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot
304305

305306
struct ggml_hash_set visited_hash_set;
306307

@@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
467468
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
468469
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
469470

471+
// return true if the node's results are only used by N other nodes
472+
// and can be fused into their calculations.
473+
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
474+
const struct ggml_tensor * node = cgraph->nodes[node_idx];
475+
476+
// check the use count against how many we're replacing
477+
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
478+
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
479+
return false;
480+
}
481+
482+
// if node is a view, some other node might be using the intermediate result
483+
// via the view source.
484+
if (node->view_src) {
485+
return false;
486+
}
487+
488+
// If the user requested output for the node, can't fuse
489+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
490+
return false;
491+
}
492+
493+
return true;
494+
}
495+
496+
// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
497+
// and are fusable. Nodes are considered fusable according to this function if:
498+
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
499+
// - all nodes except the last are a src of the following node.
500+
// - all nodes are the same shape.
501+
// TODO: Consider allowing GGML_OP_NONE nodes in between
502+
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
503+
if (node_idx + num_ops > cgraph->n_nodes) {
504+
return false;
505+
}
506+
507+
for (int i = 0; i < num_ops; ++i) {
508+
struct ggml_tensor * node = cgraph->nodes[node_idx + i];
509+
if (node->op != ops[i]) {
510+
return false;
511+
}
512+
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
513+
return false;
514+
}
515+
if (i > 0) {
516+
struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
517+
if (node->src[0] != prev && node->src[1] != prev) {
518+
return false;
519+
}
520+
if (!ggml_are_same_shape(node, prev)) {
521+
return false;
522+
}
523+
}
524+
}
525+
return true;
526+
}
527+
470528
#ifdef __cplusplus
471529
}
472530
#endif
473531

474532
#ifdef __cplusplus
533+
#include <initializer_list>
475534
#include <vector>
476535

536+
// nicer C++ syntax for ggml_can_fuse
537+
inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
538+
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
539+
}
540+
477541
// expose GGUF internals for test code
478542
GGML_API size_t gguf_type_size(enum gguf_type type);
479543
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ struct vk_device_struct {
425425
vk_pipeline pipeline_norm_f32;
426426
vk_pipeline pipeline_group_norm_f32;
427427
vk_pipeline pipeline_rms_norm_f32;
428+
vk_pipeline pipeline_rms_norm_mul_f32;
428429
vk_pipeline pipeline_rms_norm_back_f32;
429430
vk_pipeline pipeline_l2_norm_f32;
430431

@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
978979

979980
vk_command_pool compute_cmd_pool;
980981
vk_command_pool transfer_cmd_pool;
982+
983+
// number of additional consecutive nodes that are being fused with the
984+
// node currently being processed
985+
uint32_t num_additional_fused_ops {};
981986
};
982987

983988
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
26552660

26562661
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26572662
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2663+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2664+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
26592665
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26602666
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26612667

@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64306436
return nullptr;
64316437
case GGML_OP_RMS_NORM:
64326438
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6433-
return ctx->device->pipeline_rms_norm_f32;
6439+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
64346440
}
64356441
return nullptr;
64366442
case GGML_OP_RMS_NORM_BACK:
@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
75307536
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
75317537
}
75327538

7533-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7539+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
75347540
float * op_params = (float *)dst->op_params;
75357541
const uint32_t src0_type_size = ggml_type_size(src0->type);
7542+
const uint32_t src1_type_size = ggml_type_size(src1->type);
75367543
const uint32_t dst_type_size = ggml_type_size(dst->type);
75377544

7538-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7545+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
75397546
(uint32_t)ggml_nelements(src0),
7540-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7541-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7547+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7548+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7549+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
75427550
0,
7543-
op_params[0], 0.0f,
7544-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7551+
op_params[0], 0.0f, 0,
75457552
}, dryrun);
75467553
}
75477554

@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
87368743

87378744
// Returns true if node has enqueued work into the queue, false otherwise
87388745
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
8739-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8746+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8747+
ggml_tensor * node = cgraph->nodes[node_idx];
87408748
if (ggml_is_empty(node) || !node->buffer) {
87418749
return false;
87428750
}
@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
89748982

89758983
break;
89768984
case GGML_OP_RMS_NORM:
8977-
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8978-
8985+
if (ctx->num_additional_fused_ops > 0) {
8986+
// fused rms_norm + mul
8987+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
8988+
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
8989+
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
8990+
} else {
8991+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
8992+
}
89798993
break;
89808994
case GGML_OP_RMS_NORM_BACK:
89818995
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97109724

97119725
uint64_t total_mat_mul_bytes = 0;
97129726
for (int i = 0; i < cgraph->n_nodes; i++) {
9713-
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
9727+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
9728+
ctx->num_additional_fused_ops = 1;
9729+
}
9730+
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
97149731
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
97159732
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97169733
}
9734+
i += ctx->num_additional_fused_ops;
9735+
ctx->num_additional_fused_ops = 0;
97179736
}
97189737
if (ctx->device->need_compiles) {
97199738
ggml_vk_load_shaders(ctx->device);
@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97759794
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97769795
}
97779796

9797+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
9798+
ctx->num_additional_fused_ops = 1;
9799+
}
9800+
97789801
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
97799802
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
97809803
bool submit = (submitted_nodes >= nodes_per_submit) ||
97819804
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9782-
(i == last_node) ||
9805+
(i + ctx->num_additional_fused_ops == last_node) ||
97839806
(almost_ready && !ctx->almost_ready_fence_pending);
97849807

9785-
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
9808+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
97869809

97879810
if (vk_perf_logger_enabled) {
97889811
if (ctx->compute_ctx.expired()) {
@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97929815
} else {
97939816
compute_ctx = ctx->compute_ctx.lock();
97949817
}
9795-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
9818+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
9819+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
9820+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
9821+
}
97969822
}
97979823

97989824
if (enqueued) {
@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98149840
}
98159841
submit_count++;
98169842
}
9843+
i += ctx->num_additional_fused_ops;
9844+
ctx->num_additional_fused_ops = 0;
98179845
}
98189846

98199847
if (vk_perf_logger_enabled) {

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#version 450
22

3-
#include "generic_unary_head.comp"
3+
#include "generic_binary_head.comp"
44
#include "types.comp"
55

66
#extension GL_EXT_control_flow_attributes : enable
77
#define BLOCK_SIZE 512
88

9+
layout (constant_id = 1) const bool do_multiply = false;
10+
911
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
1012

1113
shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
2527
const uint stride_sample = p.nb03;
2628

2729
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
30+
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
2831
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
2932

3033
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
4649
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
4750
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
4851

49-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
50-
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
52+
if (do_multiply) {
53+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
54+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
55+
}
56+
} else {
57+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
58+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
59+
}
5160
}
5261
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ void process_shaders() {
497497
// Norms
498498
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
499499
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
500-
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
500+
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
501501
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
502502
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
503503

0 commit comments

Comments
 (0)