@@ -983,17 +983,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
983
983
const bool v_trans = v->nb [1 ] > v->nb [2 ];
984
984
985
985
// split the batch into streams if needed
986
- const auto n_stream = cparams. kv_unified ? 1 : ubatch. n_seqs_unq ;
986
+ const auto n_stream = k-> ne [ 3 ] ;
987
987
988
988
q = ggml_reshape_4d (ctx0, q, q->ne [0 ], q->ne [1 ], q->ne [2 ]/n_stream, n_stream);
989
989
990
990
q = ggml_permute (ctx0, q, 0 , 2 , 1 , 3 );
991
991
k = ggml_permute (ctx0, k, 0 , 2 , 1 , 3 );
992
992
v = ggml_permute (ctx0, v, 0 , 2 , 1 , 3 );
993
993
994
- const auto n_tokens = q->ne [1 ];
995
- const auto n_head = q->ne [2 ];
996
- const auto n_kv = k->ne [1 ];
994
+ const auto n_kv = k->ne [1 ];
997
995
998
996
ggml_tensor * cur;
999
997
@@ -1035,8 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1035
1033
#endif
1036
1034
}
1037
1035
1038
- // recombine streams
1039
- cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream);
1036
+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*cur->ne [1 ], cur->ne [2 ]*cur->ne [3 ]);
1040
1037
} else {
1041
1038
ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
1042
1039
@@ -1082,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1082
1079
cur = ggml_permute (ctx0, kqv, 0 , 2 , 1 , 3 );
1083
1080
1084
1081
// recombine streams
1085
- cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream );
1082
+ cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*cur-> ne [ 1 ], cur-> ne [ 2 ]*cur-> ne [ 3 ] );
1086
1083
1087
1084
if (!cparams.offload_kqv ) {
1088
1085
// all nodes between the KV store and the attention output are run on the CPU
@@ -1129,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
1129
1126
1130
1127
const auto & kq_mask = inp->get_kq_mask ();
1131
1128
1129
+ // [TAG_NO_CACHE_PAD]
1130
+ // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1131
+ assert (ubatch.equal_seqs == false );
1132
+
1132
1133
ggml_tensor * q = q_cur;
1133
1134
ggml_tensor * k = k_cur;
1134
1135
ggml_tensor * v = v_cur;
0 commit comments