Skip to content

Commit 18f2d00

Browse files
committed
extract some common fusion logic
1 parent 8643fea commit 18f2d00

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

ggml/src/ggml-impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,28 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
467467
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
468468
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
469469

470+
// return true if the node's results are only used by N other nodes
471+
// and can be fused into their calculations.
472+
static inline bool ggml_can_fuse_node(const struct ggml_tensor * node, int32_t N) {
473+
// check the use count against how many we're replacing
474+
if (node->use_count != N) {
475+
return false;
476+
}
477+
478+
// if node is a view, some other node might be using the intermediate result
479+
// via the view source.
480+
if (node->view_src) {
481+
return false;
482+
}
483+
484+
// If the user requested output for the node, can't fuse
485+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
486+
return false;
487+
}
488+
489+
return true;
490+
}
491+
470492
#ifdef __cplusplus
471493
}
472494
#endif

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9698,15 +9698,14 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96989698
}
96999699

97009700
// Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701-
bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9701+
static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
97029702
ggml_tensor *norm = cgraph->nodes[i];
97039703

9704-
if (norm->op != GGML_OP_RMS_NORM || norm->use_count != 1) {
9704+
if (norm->op != GGML_OP_RMS_NORM) {
97059705
return false;
97069706
}
9707-
// if norm is a view, some other node might be using the intermediate result
9708-
// view the view source.
9709-
if (norm->view_src) {
9707+
9708+
if (!ggml_can_fuse_node(norm, 1)) {
97109709
return false;
97119710
}
97129711

@@ -9721,7 +9720,6 @@ bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97219720
// Since norm is the first operand of mul, it must be the same shape
97229721
GGML_ASSERT(ggml_are_same_shape(mul, norm));
97239722

9724-
// XXX TODO: Do we need a way to indicate that the user doesn't need the intermediate result?
97259723
return true;
97269724
}
97279725

0 commit comments

Comments
 (0)