
Commit 7d85ea8

ggerganov and teleprint-me authored and committed
llama : less KV padding when FA is off (ggml-org#7257)
ggml-ci
1 parent d8b6869 · commit 7d85ea8

File tree

1 file changed: +13, -7 lines


llama.cpp

Lines changed: 13 additions & 7 deletions
@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
 
+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
@@ -11513,7 +11518,8 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+        const uint32_t pad = llama_kv_cache_get_padding(cparams);
+        kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);
     }
 }
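
To see what the new pad value does in this hunk, here is a minimal standalone C++ sketch (illustration only; GGML_PAD_LOCAL, kv_cache_padding, kv_size and cell_max are assumptions, not code from the commit):

// Minimal standalone sketch (not part of the commit): how the padding choice
// changes the effective attention window kv_self.n. GGML_PAD_LOCAL stands in
// for ggml's round-up-to-a-multiple macro; kv_size and cell_max are assumed values.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t GGML_PAD_LOCAL(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n; // round x up to the next multiple of n
}

static uint32_t kv_cache_padding(bool flash_attn) {
    // mirrors the new llama_kv_cache_get_padding(): coarse padding only when FA is on
    return flash_attn ? 256u : 32u;
}

int main() {
    const uint32_t kv_size  = 4096; // total KV cache cells (assumed)
    const uint32_t cell_max = 70;   // highest occupied cell (assumed)

    const bool cases[] = { false, true };
    for (bool fa : cases) {
        const uint32_t pad = kv_cache_padding(fa);
        const uint32_t n   = std::min(kv_size, std::max(pad, GGML_PAD_LOCAL(cell_max, pad)));
        printf("flash_attn=%d pad=%3u -> kv_self.n=%u\n", fa ? 1 : 0, pad, n);
    }
    // prints: flash_attn=0 pad= 32 -> kv_self.n=96
    //         flash_attn=1 pad=256 -> kv_self.n=256
    return 0;
}

With FA off, the window in this made-up example shrinks to 96 cells instead of 256, which is the saving the commit title refers to.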
@@ -15520,6 +15526,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
@@ -15543,7 +15554,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
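
The same helper now also drives the context-size round-up in this hunk. A rough sketch under an assumed request of 4097 tokens (previously the round-up was always to a multiple of 256):

// Hypothetical illustration of the n_ctx rounding (values assumed, not from the commit):
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx_req = 4097;          // requested context size (assumed)
    const uint32_t pads[]    = { 32u, 256u }; // non-FA vs FA padding
    for (uint32_t pad : pads) {
        const uint32_t n_ctx = ((n_ctx_req + pad - 1) / pad) * pad; // round up, as GGML_PAD does
        printf("pad=%3u -> n_ctx=%u\n", pad, n_ctx);
    }
    // prints: pad= 32 -> n_ctx=4128
    //         pad=256 -> n_ctx=4352
    return 0;
}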
@@ -15588,11 +15599,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
