
Commit fb8150d

llama : rename attn_streams -> kv_unified
ggml-ci
1 parent: 886d3f1

9 files changed: 29 additions & 25 deletions

common/arg.cpp

Lines changed: 4 additions & 4 deletions
@@ -1465,13 +1465,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--attn-streams", "-as"},
+        {"--kv-split", "-kvs"},
         string_format("use multiple streams when computing the attention (default: %s)\n"
-            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.attn_streams ? "true" : "false"),
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_split ? "true" : "false"),
         [](common_params & params) {
-            params.attn_streams = true;
+            params.kv_split = true;
         }
-    ).set_env("LLAMA_ARG_ATTN_STREAMS"));
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -1157,7 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf      = params.no_perf;
     cparams.op_offload   = !params.no_op_offload;
     cparams.swa_full     = params.swa_full;
-    cparams.attn_streams = params.attn_streams;
+    cparams.kv_unified   = !params.kv_split;

     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ struct common_params {
     bool no_perf          = false; // disable performance metrics
     bool ctx_shift        = true;  // context shift on inifinite text generation
     bool swa_full         = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
-    bool attn_streams     = false; // multi-stream attention and KV cache buffers
+    bool kv_split         = false; // disable unified KV cache

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads

include/llama.h

Lines changed: 3 additions & 4 deletions
@@ -334,10 +334,9 @@ extern "C" {
         bool swa_full;     // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                            // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                            // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
-
-        bool attn_streams; // if enabled, use multiple streams during the attention (determined by n_seq_max)
-                           // NOTE: this requires support for the ggml_set_rows() operator
-                           // this flag can improve the performance for parallel, multi-sequence use cases
+        bool kv_unified;   // use a unified buffer across the input sequences when computing the attention
+                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };

     // model quantization parameters
src/llama-context.cpp

Lines changed: 9 additions & 9 deletions
@@ -101,8 +101,8 @@ llama_context::llama_context(

     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

-    cparams.op_offload   = params.op_offload;
-    cparams.attn_streams = params.attn_streams;
+    cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;

     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

@@ -113,7 +113,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: attn_streams  = %s\n",   __func__, cparams.attn_streams ? "true" : "false");
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);

@@ -269,7 +269,7 @@ llama_context::llama_context(

     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -481,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
         throw std::runtime_error("failed to initialize memory context");
     }

-    const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
+    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -740,7 +740,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;

     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -907,7 +907,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;

-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2036,7 +2036,7 @@ void llama_context::opt_epoch_iter(
            batch.logits [pos_batch] = true;
        }

-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
            return;
        }
@@ -2195,7 +2195,7 @@ llama_context_params llama_context_default_params() {
        /*.no_perf      =*/ true,
        /*.op_offload   =*/ true,
        /*.swa_full     =*/ true,
-        /*.attn_streams =*/ false,
+        /*.kv_unified   =*/ true,
    };

    return result;
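The same ternary appears in graph reservation, kv_self_update(), and batch initialization, so the effective sizes flip together under the rename: the unified case reserves for a single stream, the split case for n_seq_max streams. A standalone illustration of the worst-case reservation sizes (the concrete numbers are illustrative assumptions, not values from this diff):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_ctx = 8192, n_ubatch = 512, n_seq_max = 4; // assumed context parameters

        for (bool kv_unified : {true, false}) {
            const uint32_t n_seqs   = kv_unified ? 1 : n_seq_max; // mirrors the reservation logic above
            const uint32_t n_tokens = std::min(n_ctx, n_ubatch);  // worst-case ubatch size
            std::printf("kv_unified=%d -> reserve for n_tokens=%u, n_seqs=%u\n", (int) kv_unified, n_tokens, n_seqs);
        }
        return 0;
    }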

src/llama-cparams.h

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
-    bool attn_streams;
+    bool kv_unified;

     enum llama_pooling_type pooling_type;


src/llama-graph.cpp

Lines changed: 2 additions & 2 deletions
@@ -1166,7 +1166,7 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie

     const auto n_kv     = mctx_cur->get_n_kv();
     const auto n_tokens = ubatch.n_tokens;
-    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1371,7 +1371,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

-    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 0 deletions
@@ -195,6 +195,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
     supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;

+    if (!supports_set_rows && !unified) {
+        LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing LLAMA_SET_ROWS=1\n", __func__);
+        supports_set_rows = 1;
+    }
+
     if (!supports_set_rows) {
         LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
     }

src/llama-model.cpp

Lines changed: 3 additions & 3 deletions
@@ -16120,7 +16120,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,

     uint32_t n_ctx_per_stream = cparams.n_ctx;

-    if (cparams.attn_streams) {
+    if (!cparams.kv_unified) {
         n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
         n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

@@ -16143,7 +16143,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
                     params.swa_full,
-                    !cparams.attn_streams,
+                    cparams.kv_unified,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     cparams.n_ubatch,
@@ -16158,7 +16158,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     params.type_v,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
-                    !cparams.attn_streams,
+                    cparams.kv_unified,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     padding,
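The per-stream context size in create_memory() is a ceiling division over the sequences followed by padding. A hedged worked example with made-up numbers (pad_to is a stand-in for GGML_PAD, rounding up to a multiple of the padding):

    #include <cstdint>
    #include <cstdio>

    // stand-in for GGML_PAD: round x up to the next multiple of n
    static uint32_t pad_to(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

    int main() {
        const uint32_t n_ctx = 10000, n_seq_max = 3, padding = 256; // illustrative values

        // kv_unified == true: a single stream keeps the whole context
        uint32_t n_ctx_per_stream = n_ctx;

        // kv_unified == false: split the context across the sequences
        n_ctx_per_stream = (n_ctx + n_seq_max - 1) / n_seq_max; // ceil(10000 / 3) = 3334
        n_ctx_per_stream = pad_to(n_ctx_per_stream, padding);   // rounded up to 3584

        std::printf("n_ctx_per_stream = %u\n", n_ctx_per_stream);
        return 0;
    }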
