Commit 52b9007

llama : add "virtual sequences"
ggml-ci
1 parent 36f8e20 commit 52b9007

15 files changed: +492 −204 lines

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     uint64_t nb31;
+    uint64_t nb32;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 0 deletions
@@ -4882,6 +4882,7 @@ static bool ggml_metal_encode_node(
         /*.nb22 =*/ nb22,
         /*.nb23 =*/ nb23,
         /*.nb31 =*/ nb31,
+        /*.nb32 =*/ nb32,
         /*.ne1 =*/ ne1,
         /*.ne2 =*/ ne2,
         /*.scale =*/ scale,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 2 deletions
@@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext(
         // load the mask in shared memory
         #pragma unroll(Q)
         for (short j = 0; j < Q; ++j) {
-            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + iq3*args.nb32);

             const float m = pm[ic + tiisg];

@@ -4131,7 +4131,7 @@ kernel void kernel_flash_attn_ext_vec(
     const bool has_mask = mask != q;

     // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + iq3*args.nb32);

     float slope = 1.0f;

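Note: the two Metal kernels above now add an iq3*args.nb32 term when locating the mask row, so each virtual sequence (indexed by iq3) reads its own 2D slice of the now-3D KQ mask. A minimal host-side sketch of the same addressing, with illustrative names (this is not the kernel code itself):

#include <cstdint>

// mask_base points at the start of the 3D mask; nb31/nb32 are the byte strides
// taken from the ggml tensor (row stride and per-sequence slice stride)
const char * mask_row_ptr(const char * mask_base,
                          int64_t  row,   // query row within the sequence (iq1 + j above)
                          int64_t  seq,   // virtual sequence index (iq3 above)
                          uint64_t nb31,
                          uint64_t nb32) {
    // previously the offset was only row*nb31; the extra seq*nb32 term selects
    // the mask slice that belongs to this virtual sequence
    return mask_base + row*nb31 + seq*nb32;
}
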
ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
@@ -3526,7 +3526,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }

@@ -4504,7 +4504,7 @@ struct ggml_tensor * ggml_flash_attn_ext(

     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");

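The relaxed asserts above mean the KQ mask may now be 3D: one [n_kv, n_tokens] slice per sequence stacked along dimension 2, with mask->ne[2] matching q->ne[3] for ggml_flash_attn_ext. A minimal sketch of allocating such a mask, using made-up sizes:

#include "ggml.h"

void make_per_seq_mask_example() {
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_kv     = 256; // KV cells visible to the ubatch (illustrative)
    const int64_t n_tokens = 32;  // tokens per sequence in the ubatch (illustrative)
    const int64_t n_seqs   = 4;   // virtual sequences (illustrative)

    // 3D mask: before this commit it had to be a matrix (ne[2] == 1);
    // the token dimension is padded as required by the flash-attention kernel
    ggml_tensor * mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
            n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_seqs);
    (void) mask; // would be filled per sequence and used with a q whose ne[3] == n_seqs

    ggml_free(ctx);
}
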
src/llama-batch.cpp

Lines changed: 7 additions & 0 deletions
@@ -460,6 +460,8 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
     std::vector<seq_set_t> cur_seq_set;

+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {

@@ -476,9 +478,14 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }

+        // accept only increasing sequence ids
+        add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);

+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }

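The added check above means split_equal() only accepts a token whose primary sequence id extends a consecutive run started by the first accepted token, so the sequences in a ubatch end up ordered and gap-free. A standalone, simplified sketch of that filter with hypothetical types:

#include <cstdint>
#include <vector>

using llama_seq_id = int32_t;

// returns the indices of the tokens whose seq id continues the consecutive run
std::vector<int> pick_increasing_seq_tokens(const std::vector<llama_seq_id> & seq_id) {
    std::vector<int> accepted;
    llama_seq_id last_seq_id = -1;

    for (int i = 0; i < (int) seq_id.size(); ++i) {
        // mirrors the added "accept only increasing sequence ids" condition
        const bool add = accepted.empty() || seq_id[i] == last_seq_id + 1;
        if (add) {
            accepted.push_back(i);
            last_seq_id = seq_id[i];
        }
    }
    return accepted;
}
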
src/llama-context.cpp

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }

+    const char * LLAMA_HT = getenv("LLAMA_HT");
+    cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;

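The feature is gated behind an environment variable for now: if LLAMA_HT is set, the context uses one virtual sequence per sequence (n_seq_virt == n_seq_max), otherwise it keeps the previous single-virtual-sequence behaviour. A small sketch of the same all-or-nothing selection, outside of llama_context:

#include <cstdlib>
#include <cstdint>

// mirrors the added logic: virtual sequences are either disabled (1) or
// enabled for every sequence (n_seq_max); there is no intermediate setting yet
uint32_t resolve_n_seq_virt(uint32_t n_seq_max) {
    return std::getenv("LLAMA_HT") ? n_seq_max : 1;
}
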
src/llama-cparams.h

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    uint32_t n_seq_virt;
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing

     float rope_freq_base;
     float rope_freq_scale;

src/llama-graph.cpp

Lines changed: 8 additions & 3 deletions
@@ -1031,6 +1031,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                   float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);

@@ -1079,7 +1083,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1124,7 +1128,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU

@@ -1203,11 +1207,12 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

     const auto n_kv   = mctx_cur->get_n_kv();
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

     inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
     ggml_set_input(inp->self_kv_idxs);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);

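With virtual sequences enabled, build_attn_mha() above folds the ubatch's token dimension into (n_tokens/n_seqs, n_seqs), so attention (and its KQ mask, now built as a 3D tensor) is evaluated per sequence and the result is flattened back afterwards. Illustrative shape arithmetic with made-up sizes:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_tokens = 8;   // tokens in the ubatch, summed over sequences
    const int64_t n_seqs   = 4;   // virtual sequences participating
    const int64_t n_kv     = 256; // visible KV cells

    // q       : [head_dim, n_head, n_tokens] -> [head_dim, n_head, n_tokens/n_seqs, n_seqs]
    // KQ mask : [n_kv, PAD(n_tokens)]        -> [n_kv, PAD(n_tokens/n_seqs), n_seqs]
    printf("query rows per sequence: %lld\n", (long long) (n_tokens/n_seqs));
    printf("mask slices            : %lld of [%lld, %lld] (before padding)\n",
            (long long) n_seqs, (long long) n_kv, (long long) (n_tokens/n_seqs));
    return 0;
}
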
src/llama-graph.h

Lines changed: 6 additions & 6 deletions
@@ -254,8 +254,8 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     // TODO: should this be I64?
     ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]

-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]

     const llama_hparams & hparams;
     const llama_cparams & cparams;

@@ -285,10 +285,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * self_kv_idxs     = nullptr; // I32 [n_batch]
     ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]

-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]

     const llama_hparams & hparams;
     const llama_cparams & cparams;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 10 additions & 4 deletions
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool     swa_full,
         uint32_t kv_size,
         uint32_t n_seq_max,
+        uint32_t n_seq_virt,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

     const uint32_t size_base = kv_size;

-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));

     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {

@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
             0, LLAMA_SWA_TYPE_NONE);

     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
             hparams.n_swa, hparams.swa_type);
 }

@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

     // first try simple split
     do {
+        if (n_seq_virt > 1) {
+            // requires equal splits
+            break;
+        }
+
         balloc.split_reset();

         std::vector<llama_ubatch> ubatches;

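One practical effect is the SWA cache sizing above: the number of SWA windows that a single cache has to hold is now n_seq_max/n_seq_virt instead of n_seq_max (presumably because each virtual sequence gets its own slice of cells). A worked example with made-up sizes:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// same rounding idea as GGML_PAD
static uint32_t pad_to(uint32_t x, uint32_t p) { return ((x + p - 1)/p)*p; }

int main() {
    const uint32_t n_swa = 4096, n_seq_max = 4, n_ubatch = 512, n_pad = 256, size_base = 32768;

    // n_seq_virt == 1 (previous behaviour): every sequence reserves its own SWA window
    const uint32_t swa_old = std::min(size_base, pad_to(n_swa*(n_seq_max/1) + n_ubatch, n_pad));         // 16896
    // n_seq_virt == n_seq_max (the LLAMA_HT path): the divisor cancels n_seq_max
    const uint32_t swa_new = std::min(size_base, pad_to(n_swa*(n_seq_max/n_seq_max) + n_ubatch, n_pad)); //  4608

    printf("size_swa: %u -> %u cells\n", swa_old, swa_new);
    return 0;
}
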