Skip to content

Commit a15aa48

Browse files
jeffbolznvslaren
authored andcommitted
vulkan: Add fusion support for RMS_NORM+MUL (ggml-org#14366)
* vulkan: Add fusion support for RMS_NORM+MUL - Add a use_count to ggml_tensor, so we can detect if an output is used more than once. - Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor. - Add detection logic and basic fusion logic in ggml-vulkan. - Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test. * extract some common fusion logic * fix -Winconsistent-missing-override * move ggml_can_fuse to a common function * build fix * C and C++ versions of can_fuse * move use count to the graph to avoid data races and double increments when used in multiple threads * use hash table lookup to find node index * change use_counts to be indexed by hash table slot * minimize hash lookups style fixes * last node doesn't need single use. fix type. handle mul operands being swapped. * remove redundant parameter --------- Co-authored-by: slaren <slarengh@gmail.com>
1 parent a9478f0 commit a15aa48

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,10 @@ struct ggml_backend_vk_context {
991991

992992
vk_command_pool compute_cmd_pool;
993993
vk_command_pool transfer_cmd_pool;
994+
995+
// number of additional consecutive nodes that are being fused with the
996+
// node currently being processed
997+
uint32_t num_additional_fused_ops {};
994998
};
995999

9961000
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4896,6 +4896,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
48964896

48974897
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
48984898

4899+
test_cases.emplace_back(new test_llama(2, true));
4900+
// these tests are disabled to save execution time, but they can be handy for debugging
48994901
#if 0
49004902
// these tests are disabled to save execution time, sbut they can be handy for debugging
49014903
test_cases.emplace_back(new test_llama(2, true));

0 commit comments

Comments
 (0)