
Commit c9b4442

ggerganov authored and Minh141120 committed
kv-cache : use ggml_set_rows (ggml-org#14285)
* kv-cache : use ggml_set_rows
  ggml-ci
* graph : separate k and v indices
  ggml-ci
* cont : remove redundant ifs
  ggml-ci
* kv-cache : improve find_slot impl
* kv-cache : bounds-check when accessing slot_info indices
* kv-cache : add comments
  ggml-ci
* ggml : add TODOs for adding GGML_OP_SET_ROWS support in the backends
  ggml-ci
1 parent f336e6e commit c9b4442

13 files changed, +437 −141 lines changed
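The gist of the change: instead of copying each ubatch's K/V data into pre-computed views of the cache with ggml_cpy, the graph now builds I64 index tensors that name the destination cells and writes through ggml_set_rows. Below is a minimal sketch of that primitive, assuming the ggml_set_rows(ctx, dst, src, idxs) signature introduced by the referenced PR ggml-org/llama.cpp#14274; the tensor names and sizes are illustrative only, not code from this commit.

```cpp
// Minimal sketch (not from this commit): scatter per-token rows into a larger
// tensor by index, the primitive the KV cache write path now relies on.
#include "ggml.h"

static ggml_tensor * scatter_rows_example(ggml_context * ctx) {
    const int64_t n_embd = 8;   // row width (illustrative)
    const int64_t n_kv   = 32;  // destination rows, e.g. KV cache cells
    const int64_t n_tok  = 4;   // rows written for the current ubatch

    ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_kv);
    ggml_tensor * cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tok);
    ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tok);
    ggml_set_input(idxs); // destination cell indices are filled at set_input() time

    // writes row i of `cur` into row idxs[i] of `cache`
    return ggml_set_rows(ctx, cache, cur, idxs);
}
```

Because the destination is selected by an input tensor rather than by a view offset, the same graph can presumably be reused across decode steps as long as the index tensor is refilled per ubatch, which appears to be part of the motivation for this change.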

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 6 additions & 0 deletions
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
                 }
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 6 additions & 0 deletions
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             default:
                 return false;
         }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 6 additions & 0 deletions
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
             }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 0 deletions
@@ -10342,6 +10342,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CONT:
        case GGML_OP_CPY:
        case GGML_OP_DUP:
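All four backends above opt out for now, so graphs containing GGML_OP_SET_ROWS fall back to a backend that does support the op (typically the CPU). One way a caller could probe this at runtime is to build a throwaway SET_ROWS node and query the device about it; a hedged sketch, assuming the public ggml_backend_dev_supports_op() helper, not code from this commit:

```cpp
// Hedged sketch (illustrative only): probe a device for SET_ROWS support by
// constructing a representative node and asking the backend registry about it.
#include "ggml.h"
#include "ggml-backend.h"

static bool device_can_set_rows(ggml_backend_dev_t dev, ggml_context * ctx) {
    // ctx: a small scratch ggml_context; only tensor metadata is inspected here
    ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8); // cache-like buffer
    ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 2); // rows to write
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);     // destination rows

    ggml_tensor * op = ggml_set_rows(ctx, dst, src, idxs);

    return ggml_backend_dev_supports_op(dev, op);
}
```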

src/llama-graph.cpp

Lines changed: 46 additions & 22 deletions
@@ -284,19 +284,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
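The set_input() overrides now also fill the new k/v index inputs before the KQ mask. A minimal sketch of what such a set_input_k_idxs()/set_input_v_idxs() plausibly does, copying one destination cell per ubatch token into the I64 tensor; the `cells` source and the explicit bounds checks are assumptions drawn from the commit message, not the verbatim llama_kv_cache_unified implementation:

```cpp
// Hedged sketch (illustrative only): fill an I64 [n_batch] index input with the
// destination cache cell of each token in the current ubatch.
#include <cstdint>
#include <vector>
#include "ggml.h"
#include "ggml-backend.h"

static void example_set_input_idxs(ggml_tensor * dst, const std::vector<uint32_t> & cells, uint32_t n_tokens) {
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
    GGML_ASSERT((int64_t) n_tokens <= dst->ne[0]);
    GGML_ASSERT(n_tokens <= cells.size()); // bounds-check slot_info-derived indices

    std::vector<int64_t> idxs(n_tokens);
    for (uint32_t i = 0; i < n_tokens; ++i) {
        idxs[i] = cells[i]; // destination row in the layer's K (or V) cache tensor
    }

    ggml_backend_tensor_set(dst, idxs.data(), 0, n_tokens*sizeof(int64_t));
}
```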
@@ -337,9 +340,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -354,7 +358,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1001,6 +1006,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1202,8 +1210,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1234,8 +1244,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
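Each build_attn() overload now passes the index tensors into cpy_k()/cpy_v(). A hedged sketch of what the index-based K write plausibly looks like on the cache side, assuming ggml_set_rows and a 2-D view of the per-layer cache; the parameter layout and the reshape are illustrative, not the verbatim llama_kv_cache_unified implementation:

```cpp
// Hedged sketch (illustrative only): flatten the current K to one row per token
// and scatter it into the layer's cache tensor at the cells named by k_idxs.
#include "ggml.h"

static ggml_tensor * example_cpy_k(ggml_context * ctx,
        ggml_tensor * k_cache,   // assumed 2-D: [n_embd_k, n_cells]
        ggml_tensor * k_cur,     // [n_embd_head_k, n_head_kv, n_tokens], assumed contiguous
        ggml_tensor * k_idxs) {  // I64 [n_tokens], built by build_input_k_idxs()
    const int64_t n_tokens = k_cur->ne[2];
    const int64_t n_embd_k = k_cur->ne[0]*k_cur->ne[1];

    ggml_tensor * k_rows = ggml_reshape_2d(ctx, k_cur, n_embd_k, n_tokens);

    // one ggml_set_rows per layer replaces the old ggml_cpy into a pre-computed view
    return ggml_set_rows(ctx, k_cache, k_rows, k_idxs);
}
```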
@@ -1294,11 +1307,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1402,8 +1419,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1438,8 +1458,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1450,8 +1472,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

src/llama-graph.h

Lines changed: 23 additions & 1 deletion
@@ -250,8 +250,14 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
 
@@ -275,9 +281,19 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
@@ -320,8 +336,14 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
 
@@ -337,7 +359,7 @@ class llm_graph_input_one : public llm_graph_input_i {
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
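The header comments pin the new inputs to I64 [n_batch]. A minimal sketch of how such an input can be allocated, assuming a build_input_*_idxs() helper of roughly this shape (the name mirrors the calls in llama-graph.cpp above, but this is not the verbatim cache API); the values are filled later by the matching set_input_*_idxs() call:

```cpp
// Minimal sketch (assumed helper shape): allocate an I64 [n_batch] graph input
// that will hold the destination cell index of each token in the ubatch.
#include "ggml.h"

static ggml_tensor * example_build_input_idxs(ggml_context * ctx, int64_t n_tokens) {
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(idxs); // filled per ubatch at set_input() time
    return idxs;
}
```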

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 16 additions & 16 deletions
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
src/llama-kv-cache-unified-iswa.h

Lines changed: 4 additions & 2 deletions
@@ -74,6 +74,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();
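The iSWA context now receives slot_info vectors (one entry per ubatch) for the base and SWA caches instead of bare head offsets, and init_batch() asserts the two vectors have the same length. A hedged illustration of that shape and of the bounds-checked access mentioned in the commit message; the slot_info layout below is an assumption for the sketch only, the real type is llama_kv_cache_unified::slot_info_vec_t:

```cpp
// Hedged sketch (illustrative layout, not the real slot_info definition).
#include <cassert>
#include <cstdint>
#include <vector>

struct example_slot_info {
    std::vector<uint32_t> idxs; // destination cells for the tokens of one ubatch
};
using example_slot_info_vec_t = std::vector<example_slot_info>;

// bounds-checked lookup, mirroring the "bounds-check when accessing slot_info
// indices" bullet in the commit message
static uint32_t example_cell(const example_slot_info_vec_t & sinfos, size_t i_ubatch, size_t i_tok) {
    assert(i_ubatch < sinfos.size());
    assert(i_tok < sinfos[i_ubatch].idxs.size());
    return sinfos[i_ubatch].idxs[i_tok];
}
```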
