Commit fa2573e

llama : add "--attn-streams" flag
ggml-ci
1 parent 4a0ec58

File tree: 9 files changed, +46 -19 lines


common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--attn-streams", "-as"},
+        string_format("use multiple streams when computing the attention (default: %s)\n"
+                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.attn_streams ? "true" : "false"),
+        [](common_params & params) {
+            params.attn_streams = true;
+        }
+    ).set_env("LLAMA_ARG_ATTN_STREAMS"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full = params.swa_full;
+    cparams.attn_streams = params.attn_streams;
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
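
A minimal sketch of the propagation path added above, using only the common_params field and common_context_params_to_llama() shown in this diff:

#include "common.h"

// sketch only: programmatic equivalent of passing --attn-streams
llama_context_params make_ctx_params() {
    common_params params;
    params.attn_streams = true; // same field the CLI flag / env var toggles

    // the line added above copies the flag into the llama_context_params
    llama_context_params cparams = common_context_params_to_llama(params);
    return cparams;
}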

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -330,6 +330,7 @@ struct common_params {
     bool no_perf = false; // disable performance metrics
     bool ctx_shift = true; // context shift on inifinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool attn_streams = false; // multi-stream attention and KV cache buffers
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap = true; // use mmap for faster loads

include/llama.h

Lines changed: 4 additions & 0 deletions
@@ -374,6 +374,10 @@ extern "C" {
         bool swa_full;     // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                            // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                            // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+
+        bool attn_streams; // if enabled, use multiple streams during the attention (determined by n_seq_max)
+                           // NOTE: this requires support for the ggml_set_rows() operator
+                           // this flag can improve the performance for parallel, multi-sequence use cases
     };
 
     // model quantization parameters
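
A hedged usage sketch for the new context parameter (not part of the commit); llama_model_load_from_file() and llama_init_from_model() are assumed to be the current loader and context entry points in llama.h, and the model path is hypothetical:

#include "llama.h"

int main(void) {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // hypothetical path
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max    = 4;    // number of parallel sequences -> number of attention streams
    cparams.attn_streams = true; // new flag (default: false), requires ggml_set_rows() support

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    // ... decode batches that carry multiple sequence ids ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}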

src/llama-context.cpp

Lines changed: 10 additions & 6 deletions
@@ -33,9 +33,6 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
-    const char * LLAMA_HT = getenv("LLAMA_HT");
-    cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;
-
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -104,7 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload = params.op_offload;
+    cparams.op_offload   = params.op_offload;
+    cparams.attn_streams = params.attn_streams;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -115,6 +113,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: attn_streams = %s\n", __func__, cparams.attn_streams ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
@@ -270,7 +269,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -314,6 +313,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            //  auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -478,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -2192,6 +2195,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf =*/ true,
         /*.op_offload =*/ true,
         /*.swa_full =*/ true,
+        /*.attn_streams =*/ false,
     };
 
     return result;
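
For illustration only, the sizing logic changed above (the worst-case graph is reserved for n_seq_max sequences when attn_streams is set, otherwise for one) as a standalone snippet with hypothetical values:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// standalone illustration only; the values below are hypothetical
int main() {
    const bool     attn_streams = true; // as set by --attn-streams
    const uint32_t n_seq_max    = 4;
    const uint32_t n_ctx        = 8192;
    const uint32_t n_ubatch     = 512;

    const uint32_t n_seqs   = attn_streams ? n_seq_max : 1; // one KV stream per sequence
    const uint32_t n_tokens = std::min(n_ctx, n_ubatch);    // worst-case ubatch

    printf("reserve worst-case graph: n_tokens = %u, n_seqs = %u\n", (unsigned) n_tokens, (unsigned) n_seqs);
    return 0;
}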

src/llama-cparams.h

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
-    bool kv_unified;
+    bool attn_streams;
 
     enum llama_pooling_type pooling_type;
 
src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
@@ -1001,7 +1001,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;
 
     inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
@@ -1212,7 +1212,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
     const auto n_kv = mctx_cur->get_n_kv();
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1459,7 +1459,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;
 
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

src/llama-kv-cache-unified.cpp

Lines changed: 15 additions & 6 deletions
@@ -317,14 +317,23 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         // TODO: do we need synchronization here?
     }
 
-    // TODO: support this:
-    GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
-
     v_cells[s1].reset();
     for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
         if (v_cells[s0].seq_has(i, seq_id_src)) {
-            v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
+            llama_pos pos = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
             v_cells[s1].seq_add(i, seq_id_dst);
+
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
         }
     }
 
@@ -1057,7 +1066,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
 
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
@@ -1101,7 +1110,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
 
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
 
     ggml_tensor * v_view = nullptr;
 
src/llama-model.cpp

Lines changed: 3 additions & 3 deletions
@@ -14712,7 +14712,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
     uint32_t n_ctx_per_stream = cparams.n_ctx;
 
-    if (!cparams.kv_unified) {
+    if (cparams.attn_streams) {
         n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
         n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
 
@@ -14735,7 +14735,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
                     params.swa_full,
-                    cparams.kv_unified,
+                    !cparams.attn_streams,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     cparams.n_ubatch,
@@ -14750,7 +14750,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     params.type_v,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
-                    cparams.kv_unified,
+                    !cparams.attn_streams,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     padding,
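
The per-stream context size above is a ceiling division followed by padding; a standalone sketch with hypothetical numbers (pad_up() is a local stand-in for GGML_PAD):

#include <cstdint>
#include <cstdio>

// local stand-in for GGML_PAD, used here only for illustration
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // hypothetical values
    const uint32_t n_ctx     = 10000;
    const uint32_t n_seq_max = 4;
    const uint32_t padding   = 256;

    // same computation as create_memory() when cparams.attn_streams is set
    uint32_t n_ctx_per_stream = (n_ctx + n_seq_max - 1) / n_seq_max; // ceil(10000/4) = 2500
    n_ctx_per_stream = pad_up(n_ctx_per_stream, padding);            // rounds up to 2560

    printf("n_ctx_per_stream = %u\n", (unsigned) n_ctx_per_stream);
    return 0;
}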
