Skip to content

llama : add high-throughput mode #14363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 16, 2025
Merged

llama : add high-throughput mode #14363

merged 17 commits into from
Jul 16, 2025

Conversation

ggerganov
Copy link
Member

@ggerganov ggerganov commented Jun 24, 2025

target #14285

Overview

Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute.

Note

To enable this functionality, there is a temporary requirement LLAMA_SET_ROWS=1 to be set in your environment variable. In the future, this will become the default. See below for more info.
If you try to use "split KV" cache and haven't added LLAMA_SET_ROWS=1 you will see the following warning:

image

Description

One significant drawback of the unified KV cache is that it leads to performing a lot of unnecessary computation in the attention when the unified buffer is shared between many large independent sequences. The reason is that we have to view this buffer continuously and therefore we end up computing large potions of "cross-sequence attention" which we then simply discard.

With this change, we add option to split the unified KV cache buffer into multiple buffers - one for each sequence. This decouples the sequences from each other and improves the performance and memory usage of the attention when more than one sequence is used. To achieve that, when the batch reaches the attention, we split it into multiple "streams":

llama.cpp/src/llama-graph.cpp

Lines 1035 to 1044 in c96c48c

// split the batch into streams if needed
const auto n_stream = k->ne[3];
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);

Each stream has its own KV cache buffer and thus no longer "sees" the rest of the other streams - it attends only to the tokens that belong to the same stream.

With this approach we now have 2 modes:

  • The vanilla "unified" approach which we always used until now - all sequences are assigned to a single stream
  • The new "split" approach - each sequence is assigned to a separate stream

The new "split" mode is enabled by default. However it requires the LLAMA_SET_ROWS=1 environment variable to be set. Otherwise, a warning will be printed and the context will fallback to "unified" mode. In the future, after there is enough ggml_set_rows() coverage in the backends (#14661) this will become the default mode.

To force the old "unified" mode, use --kv-unified CLI arg.

API Changes

  • Add bool llama_context_params::kv_unified. Default is false

llama.cpp/include/llama.h

Lines 336 to 340 in fb8150d

// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
};

Testing

Use LLAMA_SET_ROWS=1 llama-[command] ...

Qwen 2.5 Coder 3B Q8_0, M2 Ultra

# master
make -j && ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.604.032 I llama_kv_cache_unified:      Metal KV buffer size =  4680.00 MiB
0.00.953.209 I llama_kv_cache_unified: size = 4680.00 MiB (133120 cells,  36 layers, 32 seqs), K (f16): 2340.00 MiB, V (f16): 2340.00 MiB
0.01.016.945 I llama_context:      Metal compute buffer size =  1624.05 MiB
0.01.016.947 I llama_context:        CPU compute buffer size =  1056.05 MiB
0.01.016.947 I llama_context: graph nodes  = 1195
0.01.016.947 I llama_context: graph splits = 2
main: n_kv_max = 133120, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.403 |   729.71 |    1.403 |   729.66 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.381 |   741.44 |    1.381 |   741.37 |
|   512 |     32 |   32 |  17408 |    5.320 |  3079.72 |    2.052 |   498.98 |    7.372 |  2361.33 |
|  1024 |     32 |   32 |  33792 |   11.632 |  2817.15 |    2.715 |   377.16 |   14.347 |  2355.40 |
|  2048 |     32 |   32 |  66560 |   27.419 |  2390.20 |    4.052 |   252.73 |   31.470 |  2115.00 |
|  4096 |     32 |   32 | 132096 |   71.549 |  1831.92 |    6.664 |   153.66 |   78.213 |  1688.93 |


# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.584.467 I llama_kv_cache_unified:      Metal KV buffer size =  4896.00 MiB
0.00.952.799 I llama_kv_cache_unified: size = 4896.00 MiB (  4352 cells,  36 layers, 32/32 seqs), K (f16): 2448.00 MiB, V (f16): 2448.00 MiB
0.01.002.436 I llama_context:      Metal compute buffer size =  1219.00 MiB
0.01.002.438 I llama_context:        CPU compute buffer size =    50.05 MiB
0.01.002.438 I llama_context: graph nodes  = 1231
0.01.002.438 I llama_context: graph splits = 2
main: n_kv_max = 139264, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.339 |   764.92 |    1.339 |   764.85 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.332 |   768.79 |    1.332 |   768.69 |
|   512 |     32 |   32 |  17408 |    4.903 |  3341.42 |    1.499 |   682.93 |    6.403 |  2718.84 |
|  1024 |     32 |   32 |  33792 |   10.057 |  3258.12 |    1.569 |   652.46 |   11.627 |  2906.40 |
|  2048 |     32 |   32 |  66560 |   21.213 |  3089.47 |    1.754 |   583.79 |   22.967 |  2898.10 |
|  4096 |     32 |   32 | 132096 |   46.713 |  2805.91 |    2.107 |   486.09 |   48.819 |  2705.81 |

Geamma 3 4B Q8_0, M2 Ultra

# master
make -j && ./bin/llama-batched-bench -m ../models/gemma-3-4b/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.01.609.907 I llama_kv_cache_unified_iswa: creating non-SWA KV cache, size = 133120 cells
0.01.703.014 I llama_kv_cache_unified:      Metal KV buffer size =  2600.00 MiB
0.01.902.274 I llama_kv_cache_unified: size = 2600.00 MiB (133120 cells,   5 layers, 32 seqs), K (f16): 1300.00 MiB, V (f16): 1300.00 MiB
0.01.902.278 I llama_kv_cache_unified_iswa: creating     SWA KV cache, size = 34816 cells
0.02.040.114 I llama_kv_cache_unified:      Metal KV buffer size =  3944.00 MiB
0.02.325.408 I llama_kv_cache_unified: size = 3944.00 MiB ( 34816 cells,  29 layers, 32 seqs), K (f16): 1972.00 MiB, V (f16): 1972.00 MiB
0.02.403.614 I llama_context:      Metal compute buffer size =  2068.00 MiB
0.02.403.616 I llama_context:        CPU compute buffer size =  1332.09 MiB
0.02.403.617 I llama_context: graph nodes  = 1335
0.02.403.617 I llama_context: graph splits = 2
main: n_kv_max = 133120, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.843 |   555.52 |    1.844 |   555.44 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.800 |   569.00 |    1.800 |   568.94 |
|   512 |     32 |   32 |  17408 |    6.341 |  2583.88 |    3.601 |   284.33 |    9.942 |  1750.90 |
|  1024 |     32 |   32 |  33792 |   13.832 |  2369.03 |    5.442 |   188.18 |   19.273 |  1753.29 |
|  2048 |     32 |   32 |  66560 |   31.034 |  2111.78 |    6.343 |   161.43 |   37.377 |  1780.77 |
|  4096 |     32 |   32 | 132096 |   69.326 |  1890.65 |    7.456 |   137.33 |   76.783 |  1720.39 |

# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/gemma-3-4b/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.505.130 I llama_kv_cache_unified_iswa: creating non-SWA KV cache, size = 4352 cells
0.00.603.948 I llama_kv_cache_unified:      Metal KV buffer size =  2720.00 MiB
0.00.813.515 I llama_kv_cache_unified: size = 2720.00 MiB (  4352 cells,   5 layers, 32/32 seqs), K (f16): 1360.00 MiB, V (f16): 1360.00 MiB
0.00.813.520 I llama_kv_cache_unified_iswa: creating     SWA KV cache, size = 3072 cells
0.01.198.824 I llama_kv_cache_unified:      Metal KV buffer size = 11136.00 MiB
0.01.986.031 I llama_kv_cache_unified: size = 11136.00 MiB (  3072 cells,  29 layers, 32/32 seqs), K (f16): 5568.00 MiB, V (f16): 5568.00 MiB
0.02.059.335 I llama_context:      Metal compute buffer size =  2068.00 MiB
0.02.059.340 I llama_context:        CPU compute buffer size =    78.09 MiB
0.02.059.340 I llama_context: graph nodes  = 1369
0.02.059.340 I llama_context: graph splits = 2
main: n_kv_max = 139264, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.577 |   649.36 |    1.577 |   649.26 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.568 |   652.99 |    1.568 |   652.86 |
|   512 |     32 |   32 |  17408 |    5.884 |  2784.73 |    1.769 |   578.77 |    7.653 |  2274.73 |
|  1024 |     32 |   32 |  33792 |   12.261 |  2672.46 |    1.874 |   546.44 |   14.135 |  2390.61 |
|  2048 |     32 |   32 |  66560 |   25.831 |  2537.12 |    1.962 |   522.01 |   27.793 |  2394.89 |
|  4096 |     32 |   32 | 132096 |   54.077 |  2423.79 |    2.065 |   496.00 |   56.142 |  2352.90 |

Using a more real-world example with llama-parallel:

# master
make -j && ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 16384 -fa

# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 4096 -fa

TODO

  • FA path
  • Non-FA path
  • Metal FA
  • Metal non-FA
  • CPU FA
  • CPU non-FA
  • ggml_soft_max_ext() support for virtual sequences
  • llama_memory_seq_cp support for virtual sequences
  • iSWA
  • split_equal support sequential ids
  • CUDA
  • Vulkan
  • etc.
  • more consistent sequence/virtual sequence naming
  • better term than "virtual sequence"?
  • env LLAMA_HT become regular compute parameter
  • Fix n_ctx meaning (total vs per-sequence)
  • Check input batch for no coupled sequences when HT is on
  • Require n_embd_v_gqa(il) == const when FA is off (no longer needed)
  • Save/load state

Next PRs

  • Optimize parallel encoding via (split_equal + padding) and stream split [TAG_NO_CACHE_PAD]
  • Disable and remove the defrag code when ggml_set_rows() is fully adopted
  • Add option to llama-parallel to use different RNG seeds for the different clients

@github-actions github-actions bot added examples ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Jun 24, 2025
@JohannesGaessler
Copy link
Collaborator

Right now I am comparatively less busy with my PhD so it would be a good time for me to write CUDA code that is still missing, if there is any.

@ggerganov
Copy link
Member Author

ggerganov commented Jun 24, 2025

For now, these are the necessary CUDA changes:

  • Add ggml_set_rows() support (need PR towards ggml : add ggml_set_rows #14274, can already start implementing this)
  • Extend ggml_flash_attn_ext() to support n_seq dim if it does not yet:
// old
    // q:    [n_embd_k, n_batch,     n_head,    1]
    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

// new - supports `n_seq` dimension:
    // q:    [n_embd_k, n_batch,     n_head,    n_seq]
    // k:    [n_embd_k, n_kv,        n_head_kv, n_seq]
    // v:    [n_embd_v, n_kv,        n_head_kv, n_seq] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, n_seq,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   n_seq] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

CPU might also need to be extended (not sure yet)

  • Extend ggml_soft_max_ext to support n_seq dim if it does not yet in a similar way. Also not sure about the CPU state.

Edit: the CPU versions of ggml_soft_max_ext() and ggml_flash_attn_ext() are now correct and can be used as a reference.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ab2a2bb to 1b74b9d Compare June 24, 2025 17:24
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from c246784 to 06bb08a Compare June 27, 2025 14:35
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from 82277da to 4534123 Compare June 30, 2025 14:08
@ggerganov ggerganov mentioned this pull request Jul 1, 2025
5 tasks
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch from 2f577c5 to 30b4d4e Compare July 2, 2025 12:49
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from 6179578 to dfceb01 Compare July 2, 2025 18:20
Base automatically changed from gg/kv-cache-use-set-rows to master July 3, 2025 07:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from eb5856c to ee0f729 Compare July 3, 2025 08:12
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ee0f729 to deae7cd Compare July 3, 2025 08:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from 988d0cd to dbcfcaa Compare July 3, 2025 12:11
v_cells[s].resize(kv_size);
}

// by default, all sequence ids are mapped to the 0th virtual sequence
Copy link
Collaborator

@compilade compilade Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to understand the purpose of virtual sequences.

  • Is it to make the unified cache not unified?
    • Should it be a separate cache type instead?
  • why is n_seq_virt a number and not a bool of whether or not the cache is unified?
    • Is it to eventually allow n_seq_max % n_seq_virt == 0 for a partially-unified cache?
  • Are virtual sequences intended to be used with other types of caches eventually (e.g. recurrent)?
    • The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding).

Copy link
Member Author

@ggerganov ggerganov Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Today I found a better term instead of "virtual sequences": "streams". So I'll use "streams" here and will update the code later today or tomorrow.

Is it to make the unified cache not unified?

Roughly yes. The user will be able to select between unified (i.e. single stream) or non-unified (multiple streams). Each mode has advantages in different scenarios. Single stream is good when the sequences share large common prefixes. Multiple streams are good when the sequences are mostly or completely independent from each other.

The first iteration will support 1 stream (i.e. same as master, vanilla unified KV cache) and n_seq_max streams. The latter means that each sequence id is assigned to a separate stream.

In theory, we could assign multiple sequence ids to the same stream to get a partially-unified KV cache, but this would need extra work and it might not have any useful applications. So out of scope for now.

Should it be a separate cache type instead?

There is too much similar logic. Still thinking about it, but most likely it will end up in the same cache type.

The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding)

Yes.

Comment on lines 73 to 75
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are sequential seq_ids required when virtual sequences are used?

Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are sequential seq_ids required when virtual sequences are used?

Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

Yes, we make a view of the KV cache across the streams here:

ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
const uint64_t kv_size = get_size();
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*sinfo.s0);
}

The ns var is the number of streams that participate in the current ubatch. Their stream indices range from [s0, s1].

I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

It should be possible. But I'm not sure if it would be worth - both in performance and in complexity. We can explore though.

@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the 1 + intended? Why was it added?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the per-stream views of the KV cache:

std::vector<ggml_tensor *> k_seq;
std::vector<ggml_tensor *> v_seq;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
}

These are used to implement the llama_memory_seq_cp(). This operation is no longer just assigning ids - it performs actual copy of the buffers in memory when we use multiple streams. Using these helper views, the operation is quite simple to implement:

bool is_full = true;
if (p0 > 0 && p0 + 1 < (int) get_size()) {
is_full = false;
}
if (p1 > 0 && p1 + 1 < (int) get_size()) {
is_full = false;
}
GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
for (uint32_t il = 0; il < layers.size(); ++il) {
const auto & layer = layers[il];
ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);
// TODO: do we need synchronization here?
}
// TODO: support this:
GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
v_cells[s1].reset();
for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
if (v_cells[s0].seq_has(i, seq_id_src)) {
v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
v_cells[s1].seq_add(i, seq_id_dst);
}
}
v_heads[s1] = v_heads[s0];
//for (uint32_t s = 0; s < n_seq_virt; ++s) {
// LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
//}
}

Though we cannot copy partial sequences when using multiple streams.

Comment on lines 498 to 508
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about decreasing sequence ids? Is the requirement that they are increasing, or that the included seq_ids should be in a contiguous range?

(decreasing sequence ids might not really happen often in practice though)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decreasing would also work - we just need continuous range. We can either add this, if there is an elegant way to search for this. Or we add some batch pre-processing step to move the complexity at a higher level. Or just delegate it to the user by warning when the batch is not arranged optimally.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from dbcfcaa to 33dcc3c Compare July 4, 2025 07:04
@JohannesGaessler
Copy link
Collaborator

Sorry, I'm currently not collecting that data but MMLU prompts tend to be relatively short. I've been thinking that it would make sense to add a simple server benchmarking tool; in its simplest version it would just be ~100 lines of Python code so maybe I'll quickly throw something together.

@ddh0
Copy link
Contributor

ddh0 commented Jul 12, 2025

4060 Ti 16GB reporting for duty!

attn-streams == false

LLAMA_SET_ROWS=1 llama-batched-bench -t 8 -tb 8 -m /opt/workspace/gguf/Qwen3-4B-Q8_0.gguf -c 65536 -b 1024 -ub 1024 -npp 1024 -ntg 128 -npl 1,2,4,8,16,32 -fa -ngl 999
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
1024 128 1 1152 0.175 5840.81 2.347 54.54 2.522 456.76
1024 128 2 2304 0.360 5681.14 2.554 100.25 2.914 790.64
1024 128 4 4608 0.772 5306.03 2.777 184.35 3.549 1298.32
1024 128 8 9216 1.746 4692.57 3.391 301.95 5.137 1794.02
1024 128 16 18432 4.371 3748.53 4.570 448.15 8.941 2061.58
1024 128 32 36864 14.479 2263.22 9.194 445.52 23.672 1557.27

attn_streams == true

LLAMA_SET_ROWS=1 llama-batched-bench -t 8 -tb 8 -m /opt/workspace/gguf/Qwen3-4B-Q8_0.gguf -c 65536 -b 1024 -ub 1024 -npp 1024 -ntg 128 -npl 1,2,4,8,16,32 -fa -ngl 999 -as
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
1024 128 1 1152 0.175 5834.92 2.345 54.58 2.521 456.99
1024 128 2 2304 0.347 5893.53 2.544 100.65 2.891 796.95
1024 128 4 4608 0.695 5892.13 2.758 185.66 3.453 1334.52
1024 128 8 9216 1.391 5887.79 3.320 308.43 4.711 1956.12
1024 128 16 18432 2.789 5873.76 4.264 480.31 7.053 2613.27
1024 128 32 36864 5.582 5870.73 6.021 680.30 11.602 3177.26

@JohannesGaessler
Copy link
Collaborator

I did some more performance testing with #14668 but I think it will be necessary to use longer prompts in order to get sensitivity to KV cache changes.

@ggerganov
Copy link
Member Author

ggerganov commented Jul 14, 2025

  • Renamed the llama_context_params parameter from attn_streams to kv_unified. By default it is true to keep the usual behavior

  • Renamed the common_params parameter from attn_streams to kv_split. By default it is false and can be enabled with --kv-split or -kvs CLI arg

  • When llama_context_params::kv_unified == false we will now force usage of ggml_set_rows() even if LLAMA_SET_ROWS environment variable is not set

I think these change slightly improves the user experience - they just have to think about "unified" vs "split (i.e. non-unified)" KV cache and don't have to know about "streams".

Copy link
Member

@slaren slaren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok to keep this gated as an advanced option behind the LLAMA_SET_ROWS environment variable, until all the backends implement support for it.

Other than requiring ggml_set_rows, are there any significant downsides to enabling split KV? I think that if there aren't, it should be the default. If so, it may be better to make it the default already to avoid having to change the command line option later.

@ggerganov
Copy link
Member Author

Other than requiring ggml_set_rows, are there any significant downsides to enabling split KV?

I don't think there are. We don't actually have use cases where multiple sequences share large prefixes (neither llama-server (slots are independent), nor llama-cli (single sequence) need it atm). The missing ggml_set_rows is the only concern and at the current state it will cause CPU-fallback when quantizing the cache on backends that don't support it yet (#14661).

I think it is ok to keep this gated as an advanced option behind the LLAMA_SET_ROWS environment variable, until all the backends implement support for it.

Alright. So I will set "kv split" to true, but require the LLAMA_SET_ROWS to be explicitly defined for now. If the environment variable is not defined, we will fallback to "kv unified" (as early as llama_context initialization).

@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Jul 16, 2025

I did some more performance testing using the new server-bench.py script, I used the command LLAMA_ARG_MODEL=/opt/models/llama_3.2_instruct-1b-q4_k_m.gguf python3 server-bench.py --path_server /home/johannesg/Projects/llama.cpp/build/bin/llama-server --prompt_source rng-1024-40000 in conjunction with 6x RTX 4090 (frequency limited to 1350 MHz). With master I get:

Benchmark duration:                3393.22 s
Request throughput:                0.03 requests/s = 1.77 requests/min
Total prompt length:               1985544 tokens
Average prompt length:             19855.44 tokens
Average prompt latency:            101467.55 ms
Average prompt speed:              195.68 tokens/s
Total generated tokens:            153785
Average generation depth:          20717.69 tokens
Average total generation speed:    45.32 tokens/s
Average generation speed per slot: 1.42 tokens/s / slot

With this PR I get:

Benchmark duration:                989.57 s
Request throughput:                0.10 requests/s = 6.06 requests/min
Total prompt length:               1985544 tokens
Average prompt length:             19855.44 tokens
Average prompt latency:            24507.98 ms
Average prompt speed:              810.16 tokens/s
Total generated tokens:            142166
Average generation depth:          20868.23 tokens
Average total generation speed:    143.66 tokens/s
Average generation speed per slot: 4.49 tokens/s / slot

So there is a ~3x speedup for large contexts that is consistent with the results for llama-batched-bench. Notably there are still discrepancies in the number of generated tokens, I'll investigate whether this is a bug in my script (I think the overall result still stands though). Here are plots of the generated token throughput:

gen_rate gen_rate

The first image is master, the second image is this PR. As before, the throughput with --kv-split is much more consistent.

@ggerganov
Copy link
Member Author

Note: latest commit changes --kv-split to --kv-unified - i.e. the meaning is reversed. Updating OP now.

@ggerganov
Copy link
Member Author

in conjunction with 6x RTX 4090

For this test, does it make any difference between 1x GPU and 6x GPUs? My understanding is that this test will be dominated by generation time. The prompt processing time would be very short, so pipeline parallelism would likely make small to no difference.

@ggerganov
Copy link
Member Author

ggerganov commented Jul 16, 2025

Here are results on M2 Ultra but with shorter prompts:

LLAMA_ARG_FLASH_ATTN=1 LLAMA_SET_ROWS=1 LLAMA_ARG_MODEL=./models/llama-3.2-1b-instruct/ggml-model-q4_k.gguf python3 scripts/server-bench-x.py --path_server ./build/bin/llama-server --prompt_source rng-1024-8192
Benchmark duration:                312.60 s
Request throughput:                0.32 requests/s = 19.19 requests/min
Total prompt length:               440768 tokens
Average prompt length:             4407.68 tokens
Average prompt latency:            6234.62 ms
Average prompt speed:              706.97 tokens/s
Total generated tokens:            131909
Average generation depth:          5062.24 tokens
Average total generation speed:    421.98 tokens/s
Average generation speed per slot: 13.19 tokens/s / slot
gen_rate prompt_time

The plots do look a bit unexpected. Wouldn't expect so much chaotic variations in the gen_rate.png and there are many outliers in the prompt_time.png.

@JohannesGaessler
Copy link
Collaborator

For this test, does it make any difference between 1x GPU and 6x GPUs?

It shouldn't make a meaningful difference for the performance, I just don't have enough VRAM on a single GPU to run the benchmark with a total context size of 1.4M.

outliers in prompt time

Could be when >1 requests happen to be submitted at almost the same time, would maybe make sense to investigate by also plotting the submission/end times of requests. (And since the benchmarking code is pretty new it's also possible that there are just bugs affecting the results.)

@ggerganov ggerganov merged commit 225e7a1 into master Jul 16, 2025
53 of 56 checks passed
@ggerganov ggerganov deleted the gg/llama-high-throughput branch July 16, 2025 13:35
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jul 16, 2025
@rujialiu
Copy link

Is kv cache quant not supported yet? I've just updated to master, set LLAMA_SET_ROWS=1 and run llama-server with my usual command line but got the following error (even with very small context and just say "hi" from webui):

split_equal: sequential split is not supported when there are coupled sequences in the input batch
decode: failed to find a memory slot for batch of size 156
srv  update_slots: failed to find free space in the KV cache, retrying with smaller batch size, i = 0, n_batch = 512, ret = 1

I removed -ctk q8_0 -ctv q8_0 from command line and it worked (though I had to reduce context size to fit VRAM)

@ggerganov
Copy link
Member Author

@rujialiu You can see summary of the supported backends here: #14661. For example, with CUDA, only ggml_set_rows() with F16, BF16 and FP32 are currently supported. Other types are coming in #14712.

Btw, how are you using llama-server - the message about split_equal should not appear with normal usage because we don't couple the input sequences. So it's either a bug, or you are using it some unusual way. Could you provide more details?

@rujialiu
Copy link

Thanks for the info! That message appeared before sending any request:

>llama-server.exe -a Devstral-Small-2507 -m mistralai_Devstral-Small-2507-Q2_K_L.gguf -c 524288 -fa -ngl 99 -ctk q4_0 -ctv q4_0 -nkvo -np 4 --jinja
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4070 SUPER, compute capability 8.9, VMM: yes
build: 0 (unknown) with MSVC 19.44.35207.1 for x64
system info: n_threads = 16, n_threads_batch = 16, total_threads = 32

system_info: n_threads = 16 (n_threads_batch = 16) / 32 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |

main: binding port with default address family
main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 31
main: loading model
srv    load_model: loading model 'mistralai_Devstral-Small-2507-Q2_K_L.gguf'
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4070 SUPER) - 11053 MiB free
llama_model_loader: loaded meta data with 46 key-value pairs and 363 tensors from mistralai_Devstral-Small-2507-Q2_K_L.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Devstral Small 2507
llama_model_loader: - kv   3:                            general.version str              = 2507
llama_model_loader: - kv   4:                           general.basename str              = Devstral
llama_model_loader: - kv   5:                         general.size_label str              = Small
llama_model_loader: - kv   6:                            general.license str              = apache-2.0
llama_model_loader: - kv   7:                   general.base_model.count u32              = 1
llama_model_loader: - kv   8:                  general.base_model.0.name str              = Mistral Small 3.1 24B Instruct 2503
llama_model_loader: - kv   9:               general.base_model.0.version str              = 2503
llama_model_loader: - kv  10:          general.base_model.0.organization str              = Mistralai
llama_model_loader: - kv  11:              general.base_model.0.repo_url str              = https://huggingface.co/mistralai/Mist...
llama_model_loader: - kv  12:                               general.tags arr[str,1]       = ["text2text-generation"]
llama_model_loader: - kv  13:                          general.languages arr[str,24]      = ["en", "fr", "de", "es", "pt", "it", ...
llama_model_loader: - kv  14:                          llama.block_count u32              = 40
llama_model_loader: - kv  15:                       llama.context_length u32              = 131072
llama_model_loader: - kv  16:                     llama.embedding_length u32              = 5120
llama_model_loader: - kv  17:                  llama.feed_forward_length u32              = 32768
llama_model_loader: - kv  18:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  19:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  20:                       llama.rope.freq_base f32              = 1000000000.000000
llama_model_loader: - kv  21:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  22:                 llama.attention.key_length u32              = 128
llama_model_loader: - kv  23:               llama.attention.value_length u32              = 128
llama_model_loader: - kv  24:                           llama.vocab_size u32              = 131072
llama_model_loader: - kv  25:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  26:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  27:                         tokenizer.ggml.pre str              = tekken
llama_model_loader: - kv  28:                      tokenizer.ggml.tokens arr[str,131072]  = ["<unk>", "<s>", "</s>", "[INST]", "[...
llama_model_loader: - kv  29:                  tokenizer.ggml.token_type arr[i32,131072]  = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  30:                      tokenizer.ggml.merges arr[str,269443]  = ["臓 臓", "臓 t", "e r", "i n", "臓 ?..
llama_model_loader: - kv  31:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  32:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  33:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  34:            tokenizer.ggml.padding_token_id u32              = 11
llama_model_loader: - kv  35:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  36:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  37:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  38:                    tokenizer.chat_template str              = {%- set default_system_message = 'You...
llama_model_loader: - kv  39:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  40:               general.quantization_version u32              = 2
llama_model_loader: - kv  41:                          general.file_type u32              = 10
llama_model_loader: - kv  42:                      quantize.imatrix.file str              = /models_out/Devstral-Small-2507-GGUF/...
llama_model_loader: - kv  43:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav3.txt
llama_model_loader: - kv  44:             quantize.imatrix.entries_count u32              = 280
llama_model_loader: - kv  45:              quantize.imatrix.chunks_count u32              = 499
llama_model_loader: - type  f32:   81 tensors
llama_model_loader: - type q8_0:    2 tensors
llama_model_loader: - type q2_K:  160 tensors
llama_model_loader: - type q3_K:   80 tensors
llama_model_loader: - type q4_K:   40 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q2_K - Medium
print_info: file size   = 8.88 GiB (3.24 BPW)
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 1000
load: token to piece cache size = 0.8498 MB
print_info: arch             = llama
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 5120
print_info: n_layer          = 40
print_info: n_head           = 32
print_info: n_head_kv        = 8
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 4
print_info: n_embd_k_gqa     = 1024
print_info: n_embd_v_gqa     = 1024
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 32768
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 1000000000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 131072
print_info: rope_finetuned   = unknown
print_info: model type       = 13B
print_info: model params     = 23.57 B
print_info: general.name     = Devstral Small 2507
print_info: vocab type       = BPE
print_info: n_vocab          = 131072
print_info: n_merges         = 269443
print_info: BOS token        = 1 '<s>'
print_info: EOS token        = 2 '</s>'
print_info: UNK token        = 0 '<unk>'
print_info: PAD token        = 11 '<pad>'
print_info: LF token         = 1010 '膴'
print_info: EOG token        = 2 '</s>'
print_info: max token length = 150
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:        CUDA0 model buffer size =  8415.96 MiB
load_tensors:   CPU_Mapped model buffer size =   680.00 MiB
.......................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 4
llama_context: n_ctx         = 524288
llama_context: n_ctx_per_seq = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: kv_unified    = false
llama_context: freq_base     = 1000000000.0
llama_context: freq_scale    = 1
llama_context:  CUDA_Host  output buffer size =     2.00 MiB
llama_kv_cache_unified:        CPU KV buffer size = 23040.00 MiB
llama_kv_cache_unified: size = 23040.00 MiB (131072 cells,  40 layers,  4/ 4 seqs), K (q4_0): 11520.00 MiB, V (q4_0): 11520.00 MiB
llama_context:      CUDA0 compute buffer size =   996.00 MiB
llama_context:  CUDA_Host compute buffer size =   266.01 MiB
llama_context: graph nodes  = 1247
llama_context: graph splits = 82
common_init_from_params: added </s> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 524288
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
split_equal: sequential split is not supported when there are coupled sequences in the input batch
decode: failed to find a memory slot for batch of size 2
srv          init: initializing slots, n_slots = 4
slot         init: id  0 | task -1 | new slot n_ctx_slot = 131072
slot         init: id  1 | task -1 | new slot n_ctx_slot = 131072
slot         init: id  2 | task -1 | new slot n_ctx_slot = 131072
slot         init: id  3 | task -1 | new slot n_ctx_slot = 131072
main: model loaded
main: chat template, chat_template: {%- set default_system_message = 'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. You can interact with a computer to solve tasks.\\n\\n<ROLE>\\nYour primary role is to assist users by executing commands, modifying code, and solving technical problems effectively. You should be thorough, methodical, and prioritize quality over speed.\\n* If the user asks a question, like \\"why is X happening\\", don\'t try to fix the problem. Just give an answer to the question.\\n</ROLE>\\n\\n<EFFICIENCY>\\n* Each action you take is somewhat expensive. Wherever possible, combine multiple actions into a single action, e.g. combine multiple bash commands into one, using sed and grep to edit/view multiple files at once.\\n* When exploring the codebase, use efficient tools like find, grep, and git commands with appropriate filters to minimize unnecessary operations.\\n</EFFICIENCY>\\n\\n<FILE_SYSTEM_GUIDELINES>\\n* When a user provides a file path, do NOT assume it\'s relative to the current working directory. First explore the file system to locate the file before working on it.\\n* If asked to edit a file, edit the file directly, rather than creating a new file with a different filename.\\n* For global search-and-replace operations, consider using `sed` instead of opening file editors multiple times.\\n</FILE_SYSTEM_GUIDELINES>\\n\\n<CODE_QUALITY>\\n* Write clean, efficient code with minimal comments. Avoid redundancy in comments: Do not repeat information that can be easily inferred from the code itself.\\n* When implementing solutions, focus on making the minimal changes needed to solve the problem.\\n* Before implementing any changes, first thoroughly understand the codebase through exploration.\\n* If you are adding a lot of code to a function or file, consider splitting the function or file into smaller pieces when appropriate.\\n</CODE_QUALITY>\\n\\n<VERSION_CONTROL>\\n* When configuring git credentials, use \\"openhands\\" as the user.name and \\"openhands@all-hands.dev\\" as the user.email by default, unless explicitly instructed otherwise.\\n* Exercise caution with git operations. Do NOT make potentially dangerous changes (e.g., pushing to main, deleting repositories) unless explicitly asked to do so.\\n* When committing changes, use `git status` to see all modified files, and stage all files necessary for the commit. Use `git commit -a` whenever possible.\\n* Do NOT commit files that typically shouldn\'t go into version control (e.g., node_modules/, .env files, build directories, cache files, large binaries) unless explicitly instructed by the user.\\n* If unsure about committing certain files, check for the presence of .gitignore files or ask the user for clarification.\\n</VERSION_CONTROL>\\n\\n<PULL_REQUESTS>\\n* When creating pull requests, create only ONE per session/issue unless explicitly instructed otherwise.\\n* When working with an existing PR, update it with new commits rather than creating additional PRs for the same issue.\\n* When updating a PR, preserve the original PR title and purpose, updating description only when necessary.\\n</PULL_REQUESTS>\\n\\n<PROBLEM_SOLVING_WORKFLOW>\\n1. EXPLORATION: Thoroughly explore relevant files and understand the context before proposing solutions\\n2. ANALYSIS: Consider multiple approaches and select the most promising one\\n3. TESTING:\\n   * For bug fixes: Create tests to verify issues before implementing fixes\\n   * For new features: Consider test-driven development when appropriate\\n   * If the repository lacks testing infrastructure and implementing tests would require extensive setup, consult with the user before investing time in building testing infrastructure\\n   * If the environment is not set up to run tests, consult with the user first before investing time to install all dependencies\\n4. IMPLEMENTATION: Make focused, minimal changes to address the problem\\n5. VERIFICATION: If the environment is set up to run tests, test your implementation thoroughly, including edge cases. If the environment is not set up to run tests, consult with the user first before investing time to run tests.\\n</PROBLEM_SOLVING_WORKFLOW>\\n\\n<SECURITY>\\n* Only use GITHUB_TOKEN and other credentials in ways the user has explicitly requested and would expect.\\n* Use APIs to work with GitHub or other platforms, unless the user asks otherwise or your task requires browsing.\\n</SECURITY>\\n\\n<ENVIRONMENT_SETUP>\\n* When user asks you to run an application, don\'t stop if the application is not installed. Instead, please install the application and run the command again.\\n* If you encounter missing dependencies:\\n  1. First, look around in the repository for existing dependency files (requirements.txt, pyproject.toml, package.json, Gemfile, etc.)\\n  2. If dependency files exist, use them to install all dependencies at once (e.g., `pip install -r requirements.txt`, `npm install`, etc.)\\n  3. Only install individual packages directly if no dependency files are found or if only specific packages are needed\\n* Similarly, if you encounter missing dependencies for essential tools requested by the user, install them when possible.\\n</ENVIRONMENT_SETUP>\\n\\n<TROUBLESHOOTING>\\n* If you\'ve made repeated attempts to solve a problem but tests still fail or the user reports it\'s still broken:\\n  1. Step back and reflect on 5-7 different possible sources of the problem\\n  2. Assess the likelihood of each possible cause\\n  3. Methodically address the most likely causes, starting with the highest probability\\n  4. Document your reasoning process\\n* When you run into any major issue while executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new plan and confirm with the user before proceeding.\\n</TROUBLESHOOTING>\\n' %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- if messages[0]['content'] is string %}
        {%- set system_message = messages[0]['content'] %}
    {%- else %}
        {%- set system_message = messages[0]['content'][0]['text'] %}
    {%- endif %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set system_message = default_system_message %}
    {%- set loop_messages = messages %}
{%- endif %}
    {{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- for message in loop_messages %}
    {%- if message['role'] == 'user' %}
        {%- if message['content'] is string %}
            {{- '[INST]' + message['content'] + '[/INST]' }}
        {%- else %}
            {{- '[INST]' }}
        {%- for block in message['content'] %}
            {%- if block['type'] == 'text' %}
                {{- block['text'] }}
            {%- else %}
                {{- raise_exception('Only text is supported in message content!') }}
            {%- endif %}
        {%- endfor %}
        {{- '[/INST]' }}
        {%- endif %}
    {%- elif message['role'] == 'system' %}
        {%- if message['content'] is string %}
            {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
        {%- else %}
            {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}
    {%- endif %}
    {%- elif message['role'] == 'assistant' %}
        {%- if message['content'] is string %}
            {{- message['content'] + eos_token }}
        {%- else %}
            {{- message['content'][0]['text'] + eos_token }}
        {%- endif %}
    {%- else %}
        {{- raise_exception('Only user, system and assistant roles are supported!') }}
    {%- endif %}
{%- endfor %}, example_format: '[SYSTEM_PROMPT]You are a helpful assistant[/SYSTEM_PROMPT][INST]Hello[/INST]Hi there</s>[INST]How are you?[/INST]'
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv  update_slots: all slots are idle

BTW: This is a surprisingly usable setup with 12GB VRAM to allow 4 concurrent coding agent with context length 128k. I'm been already using it with cline and able to get some non-trivial jobs done (though should be used with care) :D

@ggerganov
Copy link
Member Author

There was indeed a bug - will be fixed with #14733.

@CISC
Copy link
Collaborator

CISC commented Jul 17, 2025

@ggerganov It looks like mask is not correctly padded with parallel processing:

LLAMA_SET_ROWS=1 ./llama-cli -m LFM2-1.2B-bf16.gguf -t 8 [...] --parallel 2
[...]
llama_kv_cache_unified: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to 512
llama_kv_cache_unified:        CPU KV buffer size =    48,00 MiB
llama_kv_cache_unified: size =   48,00 MiB (  4096 cells,   6 layers,  2/ 1 seqs), K (f16):   24,00 MiB, V (f16):   24,00 MiB
llama_memory_recurrent:        CPU RS buffer size =     0,31 MiB
llama_memory_recurrent: size =    0,31 MiB (     2 cells,  16 layers,  2 seqs), R (f32):    0,31 MiB, S (f32):    0,00 MiB
llama.cpp/ggml/src/ggml.c:3740: GGML_ASSERT(mask->ne[1] >= a->ne[1]) failed

Edit: With -fa it fails in ggml_flash_attn_ext instead of ggml_soft_max_ext:

llama.cpp/ggml/src/ggml.c:4768: GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big") failed

@ggerganov
Copy link
Member Author

@CISC The llama_memory_hybrid constructor has to respect the cparams.kv_unified value. Currently it is hardcoded to 1 (i.e. true). This patch fixes the issue, but a proper fix should be implemented to handle both unified and split modes:

diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index d8e2086c8..ab6470bdf 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -31,21 +31,21 @@ llama_memory_hybrid::llama_memory_hybrid(
     hparams(model.hparams),
     mem_attn(new llama_kv_cache_unified(
         model,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
         type_k,
         type_v,
         v_trans,
         offload,
-        1,
+        false,
         kv_size,
         n_seq_max,
         n_pad,
         n_swa,
         swa_type
     )),
     mem_recr(new llama_memory_recurrent(
         model,
         filter_recr == nullptr ?
             [&](int32_t il) { return hparams.is_recurrent(il); }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) examples ggml changes relating to the ggml tensor library for machine learning hot Something that is hot Nvidia GPU Issues specific to Nvidia GPUs testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants