Skip to content

Commit e7f94f8

Browse files
committed
llama : update llama_kv_self API
ggml-ci
1 parent fb74024 commit e7f94f8

File tree

30 files changed

+386
-203
lines changed

30 files changed

+386
-203
lines changed

common/common.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,9 +893,7 @@ struct common_init_result common_init_from_params(common_params & params) {
893893
return iparams;
894894
}
895895

896-
llama_kv_cache * kv = llama_get_kv_cache(lctx);
897-
898-
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
896+
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
899897
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
900898
params.ctx_shift = false;
901899
}
@@ -1000,7 +998,7 @@ struct common_init_result common_init_from_params(common_params & params) {
1000998
if (llama_model_has_decoder(model)) {
1001999
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10021000
}
1003-
llama_kv_cache_clear(kv);
1001+
llama_kv_self_clear(lctx);
10041002
llama_synchronize(lctx);
10051003
llama_perf_context_reset(lctx);
10061004
}

common/speculative.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,8 @@ llama_tokens common_speculative_gen_draft(
171171
llama_tokens result;
172172
result.reserve(params.n_draft);
173173

174-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
175-
176174
if (reuse_n == 0) {
177-
llama_kv_cache_clear(kv);
175+
llama_kv_self_clear(ctx);
178176

179177
prompt.clear();
180178
} else {
@@ -193,14 +191,14 @@ llama_tokens common_speculative_gen_draft(
193191
}
194192

195193
if (reuse_i > 0) {
196-
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i);
197-
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i);
194+
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
195+
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
198196

199197
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
200198
}
201199

202200
if (reuse_n < (int) prompt.size()) {
203-
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1);
201+
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
204202

205203
prompt.erase(prompt.begin() + reuse_n, prompt.end());
206204
}

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ int main(int argc, char ** argv) {
5757
return 1;
5858
}
5959

60-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
61-
6260
const int32_t n_kv_max = llama_n_ctx(ctx);
6361

6462
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@@ -134,7 +132,7 @@ int main(int argc, char ** argv) {
134132

135133
const auto t_pp_start = ggml_time_us();
136134

137-
llama_kv_cache_clear(kv);
135+
llama_kv_self_clear(ctx);
138136

139137
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
140138
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -143,7 +141,7 @@ int main(int argc, char ** argv) {
143141

144142
if (is_pp_shared) {
145143
for (int32_t i = 1; i < pl; ++i) {
146-
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
144+
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
147145
}
148146
}
149147

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 {
111111
}
112112

113113
for i in 1 ..< n_parallel {
114-
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
114+
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
115115
}
116116

117117
if n_parallel > 1 {

examples/cvector-generator/cvector-generator.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
342342
}
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
346-
llama_kv_cache_clear(kv);
345+
llama_kv_self_clear(ctx);
347346
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
348347
fprintf(stderr, "%s : failed to eval\n", __func__);
349348
return false;

examples/embedding/embedding.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3434

3535
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
3636
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
37-
const llama_model * model = llama_get_model(ctx);
38-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
37+
const struct llama_model * model = llama_get_model(ctx);
3938

4039
// clear previous kv_cache values (irrelevant for embeddings)
41-
llama_kv_cache_clear(kv);
40+
llama_kv_self_clear(ctx);
4241

4342
// run model
4443
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
1313
const llama_model * model = llama_get_model(ctx);
1414
const llama_vocab * vocab = llama_model_get_vocab(model);
1515

16-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
17-
1816
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
1917

2018
for (uint64_t i = 0; i < sentences.size(); i++) {
@@ -47,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4745
}
4846

4947
// clear previous kv_cache values (irrelevant for embeddings)
50-
llama_kv_cache_clear(kv);
48+
llama_kv_self_clear(ctx);
5149
llama_set_embeddings(ctx, true);
5250
llama_set_causal_attn(ctx, false);
5351

@@ -102,11 +100,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
102100
const llama_model * model = llama_get_model(ctx);
103101
const llama_vocab * vocab = llama_model_get_vocab(model);
104102

105-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
106-
107103
llama_token eos_token = llama_vocab_eos(vocab);
108104

109-
llama_kv_cache_clear(kv);
105+
llama_kv_self_clear(ctx);
110106
llama_set_embeddings(ctx, false);
111107
llama_set_causal_attn(ctx, true);
112108

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
431431
const llama_model * model = llama_get_model(ctx);
432432
const llama_vocab * vocab = llama_model_get_vocab(model);
433433

434-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
435-
436434
const bool add_bos = llama_vocab_get_add_bos(vocab);
437435
const int n_ctx = llama_n_ctx(ctx);
438436

@@ -499,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
499497
const auto t_start = std::chrono::high_resolution_clock::now();
500498

501499
// clear the KV cache
502-
llama_kv_cache_clear(kv);
500+
llama_kv_self_clear(ctx);
503501

504502
llama_batch batch = llama_batch_init(n_batch, 0, 1);
505503

examples/infill/infill.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ int main(int argc, char ** argv) {
139139
return 1;
140140
}
141141

142-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
143-
144142
const llama_vocab * vocab = llama_model_get_vocab(model);
145143

146144
const int n_ctx_train = llama_model_n_ctx_train(model);
@@ -334,8 +332,8 @@ int main(int argc, char ** argv) {
334332
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
335333
n_past, n_left, n_ctx, params.n_keep, n_discard);
336334

337-
llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
338-
llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
335+
llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336+
llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
339337

340338
n_past -= n_discard;
341339

examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,11 +1546,9 @@ int main(int argc, char ** argv) {
15461546
return 1;
15471547
}
15481548

1549-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
1550-
15511549
test t(inst, lmodel, ctx);
15521550

1553-
llama_kv_cache_clear(kv);
1551+
llama_kv_self_clear(ctx);
15541552

15551553
// cool off before the test
15561554
if (params.delay) {
@@ -1590,7 +1588,7 @@ int main(int argc, char ** argv) {
15901588
}
15911589

15921590
for (int i = 0; i < params.reps; i++) {
1593-
llama_kv_cache_clear(kv);
1591+
llama_kv_self_clear(ctx);
15941592

15951593
uint64_t t_start = get_time_ns();
15961594

0 commit comments

Comments
 (0)