Commit 75422e8

graph : normalize Q, K, V shapes + sync cross attention (ggml-org#12449)
* graph : normalize Q, K, V shapes and add comments
* context : synchronize before getting cross attention data
* model : fix command-r attention norm check

ggml-ci

1 parent bb115d2 commit 75422e8

4 files changed: +403 additions, −247 deletions

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions

@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;

+        synchronize();
+
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
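The added synchronize() call makes sure the encoder graph has actually finished running on the backend before the cross-attention embeddings are read back from t_embd. Below is a minimal, hypothetical sketch of that pattern using the public ggml-backend API; the helper name read_embeddings and its standalone form are assumptions for illustration, not the actual llama.cpp code, which works with its internal context state.

```cpp
#include "ggml.h"
#include "ggml-backend.h"

#include <vector>

// Illustrative only: read back an encoder output tensor after the backend has
// finished executing the graph that produced it. Assumes the tensor is F32.
static std::vector<float> read_embeddings(ggml_backend_t backend, const ggml_tensor * t_embd) {
    // wait for all queued backend work (e.g. the encoder graph) to complete
    ggml_backend_synchronize(backend);

    const int64_t n_embd = t_embd->ne[0]; // embedding size per token
    const int64_t n_enc  = t_embd->ne[1]; // number of encoded tokens

    std::vector<float> out(n_embd*n_enc);

    // copy the tensor data from backend memory into the host-side buffer
    ggml_backend_tensor_get(t_embd, out.data(), 0, ggml_nbytes(t_embd));

    return out;
}
```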

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion

@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
        // note: storing RoPE-ed version of K in the KV cache
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

-       assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+       v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

        ggml_tensor * v_cache_view = nullptr;

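With the normalized shapes, v_cur now arrives as a 3D tensor [n_embd_head_v, n_head_v, n_tokens], so the hard assert on an already-2D shape is replaced by an explicit flatten to [n_embd_v_gqa, n_tokens] before the copy into the KV cache view. A minimal sketch of that reshape, with flatten_v as a hypothetical helper rather than anything in the codebase:

```cpp
#include "ggml.h"

// Illustrative only: flatten a per-head V tensor [n_embd_head_v, n_head_v, n_tokens]
// into the 2D layout [n_embd_v_gqa, n_tokens] expected by the KV cache copy,
// where n_embd_v_gqa == n_embd_head_v*n_head_v.
static ggml_tensor * flatten_v(ggml_context * ctx, ggml_tensor * v_cur,
                               int64_t n_embd_v_gqa, int64_t n_tokens) {
    // ggml_reshape_2d creates a 2D view over the same data; it requires the
    // tensor to be contiguous and to hold exactly n_embd_v_gqa*n_tokens elements
    return ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
}
```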

src/llama-graph.h

Lines changed: 12 additions & 12 deletions

@@ -487,9 +487,9 @@ struct llm_graph_context {

    ggml_tensor * build_attn_mha(
            ggml_cgraph * gf,
-           ggml_tensor * q,
-           ggml_tensor * k,
-           ggml_tensor * v,
+           ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
+           ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
+           ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            bool v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
-           ggml_tensor * q_cur,
-           ggml_tensor * k_cur,
-           ggml_tensor * v_cur,
+           ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            float kq_scale,
            int il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
-           ggml_tensor * q_cur,
-           ggml_tensor * k_cur,
-           ggml_tensor * v_cur,
+           ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            float kq_scale,
            int il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
-           ggml_tensor * q_cur,
-           ggml_tensor * k_cur,
-           ggml_tensor * v_cur,
+           ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            float kq_scale,
            int il) const;
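The new comments pin down the shape convention used across the attention builders: ggml orders dimensions innermost-first, the build_attn() overloads take Q/K/V as [n_embd_head, n_head, n_tokens], and build_attn_mha() expects the head dimension moved outermost, [n_embd_head, n_tokens, n_head] (for v_trans == false). A minimal sketch of how such layouts are typically produced; split_heads and to_mha_layout are hypothetical helpers, not part of the repository:

```cpp
#include "ggml.h"

// Illustrative only: ggml lists dimensions innermost-first, so a flat projection
// output [n_embd, n_tokens] becomes the per-head layout [n_embd_head, n_head, n_tokens]
// that the build_attn() overloads document for q_cur/k_cur/v_cur.
static ggml_tensor * split_heads(ggml_context * ctx, ggml_tensor * cur,
                                 int64_t n_embd_head, int64_t n_head, int64_t n_tokens) {
    // [n_embd_head*n_head, n_tokens] -> [n_embd_head, n_head, n_tokens]
    return ggml_reshape_3d(ctx, cur, n_embd_head, n_head, n_tokens);
}

// build_attn_mha() instead documents [n_embd_head, n_tokens, n_head]: the head
// dimension is moved outermost, e.g. via ggml_permute.
static ggml_tensor * to_mha_layout(ggml_context * ctx, ggml_tensor * q) {
    // swap dims 1 and 2: [n_embd_head, n_head, n_tokens] -> [n_embd_head, n_tokens, n_head]
    return ggml_permute(ctx, q, 0, 2, 1, 3);
}
```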
