Skip to content

Commit eb219f4

Browse files
jeffbolznv authored and slaren committed
vulkan: Add fusion support for RMS_NORM+MUL (ggml-org#14366)
* vulkan: Add fusion support for RMS_NORM+MUL - Add a use_count to ggml_tensor, so we can detect if an output is used more than once. - Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor. - Add detection logic and basic fusion logic in ggml-vulkan. - Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test. * extract some common fusion logic * fix -Winconsistent-missing-override * move ggml_can_fuse to a common function * build fix * C and C++ versions of can_fuse * move use count to the graph to avoid data races and double increments when used in multiple threads * use hash table lookup to find node index * change use_counts to be indexed by hash table slot * minimize hash lookups style fixes * last node doesn't need single use. fix type. handle mul operands being swapped. * remove redundant parameter --------- Co-authored-by: slaren <slarengh@gmail.com>
1 parent 38733a0 commit eb219f4

File tree

3 files changed

+147
-31
lines changed

3 files changed

+147
-31
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ struct vk_device_struct {
425425
vk_pipeline pipeline_norm_f32;
426426
vk_pipeline pipeline_group_norm_f32;
427427
vk_pipeline pipeline_rms_norm_f32;
428+
vk_pipeline pipeline_rms_norm_mul_f32;
428429
vk_pipeline pipeline_rms_norm_back_f32;
429430
vk_pipeline pipeline_l2_norm_f32;
430431

@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
978979

979980
vk_command_pool compute_cmd_pool;
980981
vk_command_pool transfer_cmd_pool;
982+
983+
// number of additional consecutive nodes that are being fused with the
984+
// node currently being processed
985+
uint32_t num_additional_fused_ops {};
981986
};
982987

983988
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
26552660

26562661
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26572662
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2663+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2664+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
26592665
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26602666
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26612667

@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64306436
return nullptr;
64316437
case GGML_OP_RMS_NORM:
64326438
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6433-
return ctx->device->pipeline_rms_norm_f32;
6439+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
64346440
}
64356441
return nullptr;
64366442
case GGML_OP_RMS_NORM_BACK:
@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
75307536
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
75317537
}
75327538

7533-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7539+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
75347540
float * op_params = (float *)dst->op_params;
75357541
const uint32_t src0_type_size = ggml_type_size(src0->type);
7542+
const uint32_t src1_type_size = ggml_type_size(src1->type);
75367543
const uint32_t dst_type_size = ggml_type_size(dst->type);
75377544

7538-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7545+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
75397546
(uint32_t)ggml_nelements(src0),
7540-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7541-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7547+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7548+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7549+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
75427550
0,
7543-
op_params[0], 0.0f,
7544-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7551+
op_params[0], 0.0f, 0,
75457552
}, dryrun);
75467553
}
75477554

@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
87368743

87378744
// Returns true if node has enqueued work into the queue, false otherwise
87388745
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
8739-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8746+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8747+
ggml_tensor * node = cgraph->nodes[node_idx];
87408748
if (ggml_is_empty(node) || !node->buffer) {
87418749
return false;
87428750
}
@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
89748982

89758983
break;
89768984
case GGML_OP_RMS_NORM:
8977-
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8978-
8985+
if (ctx->num_additional_fused_ops > 0) {
8986+
// fused rms_norm + mul
8987+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
8988+
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
8989+
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
8990+
} else {
8991+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
8992+
}
89798993
break;
89808994
case GGML_OP_RMS_NORM_BACK:
89818995
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97109724

97119725
uint64_t total_mat_mul_bytes = 0;
97129726
for (int i = 0; i < cgraph->n_nodes; i++) {
9713-
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
9727+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
9728+
ctx->num_additional_fused_ops = 1;
9729+
}
9730+
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
97149731
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
97159732
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97169733
}
9734+
i += ctx->num_additional_fused_ops;
9735+
ctx->num_additional_fused_ops = 0;
97179736
}
97189737
if (ctx->device->need_compiles) {
97199738
ggml_vk_load_shaders(ctx->device);
@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97759794
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97769795
}
97779796

9797+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
9798+
ctx->num_additional_fused_ops = 1;
9799+
}
9800+
97789801
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
97799802
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
97809803
bool submit = (submitted_nodes >= nodes_per_submit) ||
97819804
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9782-
(i == last_node) ||
9805+
(i + ctx->num_additional_fused_ops == last_node) ||
97839806
(almost_ready && !ctx->almost_ready_fence_pending);
97849807

9785-
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
9808+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
97869809

97879810
if (vk_perf_logger_enabled) {
97889811
if (ctx->compute_ctx.expired()) {
@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97929815
} else {
97939816
compute_ctx = ctx->compute_ctx.lock();
97949817
}
9795-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
9818+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
9819+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
9820+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
9821+
}
97969822
}
97979823

97989824
if (enqueued) {
@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98149840
}
98159841
submit_count++;
98169842
}
9843+
i += ctx->num_additional_fused_ops;
9844+
ctx->num_additional_fused_ops = 0;
98179845
}
98189846

98199847
if (vk_perf_logger_enabled) {

ggml/src/ggml.c

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5841,19 +5841,32 @@ static void ggml_compute_backward(
58415841
GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
58425842
}
58435843

5844-
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
5844+
static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
58455845
// check if already visited
5846-
if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
5847-
return;
5846+
size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
5847+
GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
5848+
if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
5849+
// This is the first time we see this node in the current graph.
5850+
cgraph->visited_hash_set.keys[node_hash_pos] = node;
5851+
ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
5852+
cgraph->use_counts[node_hash_pos] = 0;
5853+
} else {
5854+
// already visited
5855+
return node_hash_pos;
58485856
}
58495857

58505858
for (int i = 0; i < GGML_MAX_SRC; ++i) {
58515859
const int k =
58525860
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
58535861
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
5854-
/* unknown order, just fall back to using i*/ i;
5855-
if (node->src[k]) {
5856-
ggml_visit_parents(cgraph, node->src[k]);
5862+
/* unknown order, just fall back to using i */ i;
5863+
5864+
struct ggml_tensor * src = node->src[k];
5865+
if (src) {
5866+
size_t src_hash_pos = ggml_visit_parents(cgraph, src);
5867+
5868+
// Update the use count for this operand.
5869+
cgraph->use_counts[src_hash_pos]++;
58575870
}
58585871
}
58595872

@@ -5877,6 +5890,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58775890
cgraph->nodes[cgraph->n_nodes] = node;
58785891
cgraph->n_nodes++;
58795892
}
5893+
5894+
return node_hash_pos;
58805895
}
58815896

58825897
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -6014,6 +6029,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
60146029
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
60156030
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
60166031
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6032+
incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
60176033
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
60186034
if (grads) {
60196035
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -6043,11 +6059,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60436059

60446060
void * p = cgraph + 1;
60456061

6046-
struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6047-
struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6048-
struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6049-
struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6050-
struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6062+
struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6063+
struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6064+
int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6065+
struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6066+
struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6067+
struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
60516068

60526069
ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
60536070

@@ -6062,6 +6079,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60626079
/*.grads =*/ grads_ptr,
60636080
/*.grad_accs =*/ grad_accs_ptr,
60646081
/*.leafs =*/ leafs_ptr,
6082+
/*.use_counts =*/ use_counts_ptr,
60656083
/*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
60666084
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
60676085
};
@@ -6088,7 +6106,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60886106
/*.grads =*/ NULL, // gradients would need visited_hash_set
60896107
/*.grad_accs =*/ NULL,
60906108
/*.leafs =*/ NULL,
6091-
/*.visited_hash_set =*/ { 0, NULL, NULL },
6109+
/*.use_counts =*/ cgraph0->use_counts,
6110+
/*.visited_hash_set =*/ cgraph0->visited_hash_set,
60926111
/*.order =*/ cgraph0->order,
60936112
};
60946113

@@ -6115,7 +6134,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
61156134
for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
61166135
// copy all hashset keys (tensors) that are in use
61176136
if (ggml_bitset_get(src->visited_hash_set.used, i)) {
6118-
ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6137+
size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6138+
dst->use_counts[new_hash_pos] = src->use_counts[i];
61196139
}
61206140
}
61216141

tests/test-backend-ops.cpp

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ struct test_case {
382382
return 0;
383383
}
384384

385+
virtual bool run_whole_graph() { return false; }
386+
385387
ggml_cgraph * gf = nullptr;
386388
ggml_cgraph * gb = nullptr;
387389

@@ -574,7 +576,7 @@ struct test_case {
574576
GGML_UNUSED(index);
575577
};
576578

577-
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
579+
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);
578580

579581
if (!cmp_ok) {
580582
printf("compare failed ");
@@ -1896,6 +1898,63 @@ struct test_rms_norm_back : public test_case {
18961898
}
18971899
};
18981900

1901+
// GGML_OP_RMS_NORM + GGML_OP_MUL
1902+
struct test_rms_norm_mul : public test_case {
1903+
const ggml_type type;
1904+
const std::array<int64_t, 4> ne;
1905+
const float eps;
1906+
1907+
std::string op_desc(ggml_tensor * t) override {
1908+
GGML_UNUSED(t);
1909+
return "RMS_NORM_MUL";
1910+
}
1911+
1912+
bool run_whole_graph() override { return true; }
1913+
1914+
std::string vars() override {
1915+
return VARS_TO_STR3(type, ne, eps);
1916+
}
1917+
1918+
test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
1919+
std::array<int64_t, 4> ne = {64, 5, 4, 3},
1920+
float eps = 1e-6f)
1921+
: type(type), ne(ne), eps(eps) {}
1922+
1923+
ggml_tensor * build_graph(ggml_context * ctx) override {
1924+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
1925+
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
1926+
ggml_set_param(a);
1927+
ggml_set_name(a, "a");
1928+
ggml_set_param(b);
1929+
ggml_set_name(b, "b");
1930+
1931+
// Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
1932+
a = ggml_add(ctx, a, b);
1933+
ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
1934+
ggml_set_name(out, "out");
1935+
1936+
return out;
1937+
}
1938+
1939+
void initialize_tensors(ggml_context * ctx) override {
1940+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1941+
init_tensor_uniform(t, -10.f, 10.f);
1942+
}
1943+
}
1944+
1945+
double max_nmse_err() override {
1946+
return 1e-6;
1947+
}
1948+
1949+
float grad_eps() override {
1950+
return 1.0f;
1951+
}
1952+
1953+
bool grad_precise() override {
1954+
return true;
1955+
}
1956+
};
1957+
18991958
// GGML_OP_SSM_CONV
19001959
struct test_ssm_conv : public test_case {
19011960
const ggml_type type;
@@ -3736,6 +3795,7 @@ struct test_llama : public test_llm {
37363795
static constexpr float attn_factor = 1.0f;
37373796
static constexpr float beta_fast = 32.0f;
37383797
static constexpr float beta_slow = 1.0f;
3798+
bool fused;
37393799

37403800
std::string op_desc(ggml_tensor * t) override {
37413801
GGML_UNUSED(t);
@@ -3751,7 +3811,9 @@ struct test_llama : public test_llm {
37513811
return 2e-3;
37523812
}
37533813

3754-
test_llama(int n_tokens = 1)
3814+
bool run_whole_graph() override { return fused; }
3815+
3816+
test_llama(int n_tokens = 1, bool fused = false)
37553817
: test_llm({
37563818
/*n_vocab =*/ 32000,
37573819
/*n_embd =*/ 3200,
@@ -3763,7 +3825,9 @@ struct test_llama : public test_llm {
37633825
/*f_norm_eps =*/ 0.f,
37643826
/*f_norm_rms_eps =*/ 1e-5f,
37653827
/*n_tokens =*/ n_tokens,
3766-
}) {
3828+
})
3829+
, fused(fused)
3830+
{
37673831
}
37683832

37693833
ggml_tensor * build_graph(ggml_context * ctx) override {
@@ -4306,6 +4370,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
43064370
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
43074371
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
43084372
}
4373+
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
4374+
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
4375+
}
43094376

43104377
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
43114378

@@ -4677,6 +4744,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
46774744

46784745
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
46794746

4747+
test_cases.emplace_back(new test_llama(2, true));
46804748
// these tests are disabled to save execution time, but they can be handy for debugging
46814749
#if 0
46824750
test_cases.emplace_back(new test_llama(1));

0 commit comments

Comments (0)