@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }

+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
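For context, GGML_PAD rounds its first argument up to the next multiple of its second, so the helper above makes the KV-cache-related sizes multiples of 256 when flash attention is enabled and multiples of 32 otherwise. A minimal standalone sketch of that rounding behavior (the `round_up` helper here only mirrors the macro's effect and is not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Mirrors GGML_PAD's round-up-to-multiple behavior (n must be a power of two).
static uint32_t round_up(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    assert(round_up( 100, 256) ==  256); // flash_attn on:  pad to multiples of 256
    assert(round_up( 100,  32) ==  128); // flash_attn off: pad to multiples of 32
    assert(round_up(4096, 256) == 4096); // already-aligned sizes are unchanged
    return 0;
}
```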
@@ -11513,7 +11518,8 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
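To see what the new `pad` changes in practice, here is a hypothetical standalone version of the heuristic above (the `kv_n` helper and the numbers are made up for illustration; only the min/max/round-up structure comes from the diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors GGML_PAD's round-up-to-multiple behavior (n must be a power of two).
static uint32_t round_up(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

// Hypothetical stand-in for the heuristic: attend to at least `pad` cells,
// rounded up to a multiple of `pad`, but never more than the whole cache.
static uint32_t kv_n(uint32_t kv_size, uint32_t cell_max, uint32_t pad) {
    return std::min(kv_size, std::max(pad, round_up(cell_max, pad)));
}

int main() {
    printf("%u\n", kv_n(4096,   10, 256)); // 256  - FA path pays at least the 256-cell minimum
    printf("%u\n", kv_n(4096,   10,  32)); // 32   - non-FA path can attend to far fewer cells
    printf("%u\n", kv_n(4096, 1000, 256)); // 1024 - rounded up from 1000
    printf("%u\n", kv_n(4096, 4090, 256)); // 4096 - clamped to the cache size
    return 0;
}
```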
@@ -15520,6 +15526,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }

+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);

     const auto & hparams = model->hparams;
@@ -15543,7 +15554,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
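As a concrete example of the effect on the context size (values chosen for illustration): a requested `n_ctx` of 4000 is already a multiple of 32, so it is left unchanged when flash attention is off, but it is rounded up to 4096 when flash attention is on; a request of 4096 passes through unchanged in both cases. As the comment above notes, the context has to be padded here so that `kv_self.n`, which is padded the same way during inference, can never exceed it.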
@@ -15588,11 +15599,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }

-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
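The Grok workaround itself is unchanged; it only moves from after `cparams` is filled in (where it flipped `cparams.flash_attn`) to before the context is created, where it flips `params.flash_attn`. A hypothetical, self-contained sketch of the ordering this relies on (the `fake_*` structs, `get_padding`, `round_up`, and the values are stand-ins; the copy from `params` to `cparams` is assumed from the surrounding code, which these hunks do not show):

```cpp
#include <cstdint>
#include <cstdio>

// Minimal stand-ins for the real structs; only the fields relevant to the ordering are shown.
struct fake_params  { bool flash_attn; uint32_t n_ctx; };
struct fake_cparams { bool flash_attn; uint32_t n_ctx; };

static uint32_t get_padding(const fake_cparams & c) { return c.flash_attn ? 256u : 32u; }
static uint32_t round_up(uint32_t x, uint32_t n)    { return (x + n - 1) & ~(n - 1); }

int main() {
    fake_params  params  = { /*flash_attn=*/true, /*n_ctx=*/4000 }; // e.g. Grok with flash attention requested
    fake_cparams cparams = {};

    params.flash_attn  = false;              // 1. the Grok override now runs on params, up front
    cparams.flash_attn = params.flash_attn;  // 2. cparams copies the already-corrected flag
    cparams.n_ctx      = round_up(params.n_ctx, get_padding(cparams)); // 3. padding sees flash_attn == false

    printf("n_ctx = %u\n", cparams.n_ctx);   // 4000 (padded to 32), not 4096 (padded to 256)
    return 0;
}
```

With the old placement, the `n_ctx` padding above would still have been computed with `flash_attn` set for Grok, even though flash attention ends up disabled.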