@@ -1034,17 +1034,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1034
1034
const bool v_trans = v->nb [1 ] > v->nb [2 ];
1035
1035
1036
1036
// split the batch into streams if needed
1037
- const auto n_stream = cparams. kv_unified ? 1 : ubatch. n_seqs_unq ;
1037
+ const auto n_stream = k-> ne [ 3 ] ;
1038
1038
1039
1039
q = ggml_reshape_4d (ctx0, q, q->ne [0 ], q->ne [1 ], q->ne [2 ]/n_stream, n_stream);
1040
1040
1041
1041
q = ggml_permute (ctx0, q, 0 , 2 , 1 , 3 );
1042
1042
k = ggml_permute (ctx0, k, 0 , 2 , 1 , 3 );
1043
1043
v = ggml_permute (ctx0, v, 0 , 2 , 1 , 3 );
1044
1044
1045
- const auto n_tokens = q->ne [1 ];
1046
- const auto n_head = q->ne [2 ];
1047
- const auto n_kv = k->ne [1 ];
1045
+ const auto n_kv = k->ne [1 ];
1048
1046
1049
1047
ggml_tensor * cur;
1050
1048
@@ -1086,8 +1084,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1086
1084
#endif
1087
1085
}
1088
1086
1089
- // recombine streams
1090
- cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream);
1087
+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ]*cur->ne [1 ], cur->ne [2 ]*cur->ne [3 ]);
1091
1088
} else {
1092
1089
ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
1093
1090
@@ -1133,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1133
1130
cur = ggml_permute (ctx0, kqv, 0 , 2 , 1 , 3 );
1134
1131
1135
1132
// recombine streams
1136
- cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens*n_stream );
1133
+ cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*cur-> ne [ 1 ], cur-> ne [ 2 ]*cur-> ne [ 3 ] );
1137
1134
1138
1135
if (!cparams.offload_kqv ) {
1139
1136
// all nodes between the KV store and the attention output are run on the CPU
@@ -1180,6 +1177,10 @@ ggml_tensor * llm_graph_context::build_attn(
1180
1177
1181
1178
const auto & kq_mask = inp->get_kq_mask ();
1182
1179
1180
+ // [TAG_NO_CACHE_PAD]
1181
+ // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1182
+ assert (ubatch.equal_seqs == false );
1183
+
1183
1184
ggml_tensor * q = q_cur;
1184
1185
ggml_tensor * k = k_cur;
1185
1186
ggml_tensor * v = v_cur;
0 commit comments