Why don't we need a ggml_cont for v in llm_graph_context::build_attn_mha? #14351

Unanswered

AgainstEntropy asked this question in Q&A

Replies: 0
Hi community,
I'm going through llama.cpp's implementation of attention and the KV cache, and I noticed something potentially inefficient when `v_trans = true`.

In `llm_graph_context::build_attn`

llama.cpp/src/llama-graph.cpp, Lines 1217 to 1238 in ce82bd0

the V cache (`layers[ikv].v`) is first updated by copying `v_cur` into a transposed view of it via `llama_kv_cache_unified::cpy_v`:

llama.cpp/src/llama-kv-cache-unified.cpp, Lines 779 to 784 in ce82bd0

This ensures the first axis (`n_tokens`) is contiguous in memory, but the other axes are not.
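To make the layout concrete, here is a toy standalone sketch of what I understand the transposed-view copy to be doing (not the actual llama.cpp code; all sizes and variable names here are made up for illustration, using plain ggml calls):

```c
// Toy sketch of the transposed-view copy into the V cache (not the actual
// llama.cpp code; sizes are made up for illustration).
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_v_gqa = 8;   // toy sizes
    const int64_t n_ctx        = 32;
    const int64_t n_tokens     = 4;

    // cache storage, allocated as [n_embd_v_gqa, n_ctx]
    struct ggml_tensor * v_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_gqa, n_ctx);
    // incoming batch of values: [n_embd_v_gqa, n_tokens]
    struct ggml_tensor * v_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_gqa, n_tokens);

    const size_t e = ggml_element_size(v_cache);

    // destination view: [n_tokens, n_embd_v_gqa] with nb[1] = e*n_ctx,
    // i.e. tokens sit next to each other, embedding dims are n_ctx apart
    struct ggml_tensor * v_view = ggml_view_2d(ctx, v_cache,
            n_tokens, n_embd_v_gqa, e*n_ctx, /*offset =*/ 0);

    // transpose v_cur so the shapes line up, then copy into the view
    // (in the real code this op is added to the graph and executed later)
    struct ggml_tensor * cpy = ggml_cpy(ctx, ggml_transpose(ctx, v_cur), v_view);
    (void) cpy;

    printf("v_view: ne = [%lld, %lld], nb = [%zu, %zu]\n",
           (long long) v_view->ne[0], (long long) v_view->ne[1],
           v_view->nb[0], v_view->nb[1]);
    // -> ne = [4, 8], nb = [4, 128]: axis 0 (tokens) is contiguous, axis 1 is not

    ggml_free(ctx);
    return 0;
}
```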
Then, when the V cache is fetched via `llama_kv_cache_unified::get_v` and passed to `build_attn_mha`

llama.cpp/src/llama-kv-cache-unified.cpp, Lines 741 to 746 in ce82bd0

it is a view with `ne = [n_tokens, n_head_kv, n_embd_head_v, 1]` and `nb = [e, e * n_ctx * n_embd_head_v, e * n_ctx, e * n_ctx * n_embd_head_v]` (with `e = element_size(v)`), since the underlying storage has shape `[n_embd_v_gqa, n_ctx]`:

llama.cpp/src/llama-kv-cache-unified.cpp, Line 96 in ce82bd0

llama.cpp/src/llama-model.cpp, Lines 13805 to 13816 in 9eaa51e
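Continuing the toy sketch above (same `ctx`, `v_cache`, and `e`; the head sizes are again made up), I believe the view that `get_v` hands back can be reproduced with a plain `ggml_view_3d`:

```c
// Continuing the sketch: a 3D view over v_cache with the strides described
// above (toy head sizes; n_head_kv * n_embd_head_v == n_embd_v_gqa).
const int64_t n_head_kv     = 2;
const int64_t n_embd_head_v = 4;

struct ggml_tensor * v = ggml_view_3d(ctx, v_cache,
        n_tokens, n_head_kv, n_embd_head_v,
        e*n_ctx*n_embd_head_v,  // nb[1]: hop to the next head
        e*n_ctx,                // nb[2]: hop to the next embedding dim
        /*offset =*/ 0);
// -> ne = [n_tokens, n_head_kv, n_embd_head_v, 1]
//    nb = [e, e*n_ctx*n_embd_head_v, e*n_ctx, e*n_ctx*n_embd_head_v]
```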
Inside `build_attn_mha`, `v` is then permuted via `ggml_permute(0, 2, 1, 3)`, yielding `ne = [n_tokens, n_embd_head_v, n_head_kv, 1]` and `nb = [e, e * n_ctx, e * n_ctx * n_embd_head_v, e * n_ctx * n_embd_head_v]`:

llama.cpp/src/llama-graph.cpp, Line 1024 in ce82bd0

and this permuted `v` is what gets used in `ggml_mul_mat(ctx0, v, kq)`.
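Again in the toy sketch, the permutation just swaps axes 1 and 2 of that view without moving any data, and ggml reports the result as non-contiguous:

```c
// Continuing the sketch: permute axes 1 <-> 2, as build_attn_mha does
struct ggml_tensor * v_p = ggml_permute(ctx, v, 0, 2, 1, 3);
// -> ne = [n_tokens, n_embd_head_v, n_head_kv, 1]
//    nb = [e, e*n_ctx, e*n_ctx*n_embd_head_v, e*n_ctx*n_embd_head_v]
printf("contiguous: %d\n", ggml_is_contiguous(v_p));  // -> 0
```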
Here's my concern: after that permutation, `v` is not fully contiguous, especially along the 2nd and 3rd axes. This likely leads to a non-contiguous `mul_mat`, which can hurt performance. Would adding a `ggml_cont(v)` before the `ggml_mul_mat` improve performance? It would make all axes contiguous: `ne = [n_tokens, n_embd_head_v, n_head_kv, 1]` and `nb = [e, e * n_tokens, e * n_tokens * n_embd_head_v, e * n_tokens * n_embd_head_v]`. Concretely, something like the sketch below.
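In sketch form (still the toy program from above; `kq` is a placeholder I made up to stand in for the attention-weights tensor at that point in the graph, with only its shape mattering here):

```c
// Continuing the sketch: make v contiguous before the matmul.
// kq is a stand-in for the real K*Q product, not the actual llama.cpp tensor.
struct ggml_tensor * kq  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
        n_tokens /*n_kv*/, n_tokens /*n_q*/, n_head_kv);
struct ggml_tensor * v_c = ggml_cont(ctx, v_p);      // copy into a contiguous buffer
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v_c, kq);
// kqv: ne = [n_embd_head_v, n_q, n_head_kv]
(void) kqv;
```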
While `ggml_cont` may introduce a copy, I suspect its cost is less than that of an inefficient `mul_mat`. Or is `v` already made contiguous elsewhere before the matmul? Am I missing something here? I'd appreciate any insights or corrections. Thanks in advance!