Skip to content

Commit 5f44f4b

Browse files
Iwan Kawrakow authored
Guard against attempts to use MLA for non-MLA models (ikawrakow#320)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 22d7440 commit 5f44f4b

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/llama.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3245,13 +3245,15 @@ static bool llama_kv_cache_init(
32453245
cache.ctxs.push_back(ctx);
32463246
}
32473247

3248-
cache.k_l.reserve(n_layer);
3249-
cache.v_l.reserve(n_layer);
3250-
3251-
// DeepSeek MLA
3252-
cache.kv_l.reserve(n_layer);
3253-
if (cparams.mla_attn == 1 && !cparams.flash_attn) {
3254-
cache.kvt_l.reserve(n_layer);
3248+
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
3249+
// DeepSeek MLA
3250+
cache.kv_l.reserve(n_layer);
3251+
if (cparams.mla_attn == 1 && !cparams.flash_attn) {
3252+
cache.kvt_l.reserve(n_layer);
3253+
}
3254+
} else {
3255+
cache.k_l.reserve(n_layer);
3256+
cache.v_l.reserve(n_layer);
32553257
}
32563258

32573259
bool warn = true;
@@ -3299,7 +3301,7 @@ static bool llama_kv_cache_init(
32993301
cache.v_l.push_back(v);
33003302
}
33013303
}
3302-
if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
3304+
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
33033305
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
33043306
LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
33053307
GGML_ABORT("fatal error");
@@ -18568,6 +18570,13 @@ struct llama_context * llama_new_context_with_model(
1856818570
params.seed = time(NULL);
1856918571
}
1857018572

18573+
if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn > 0) {
18574+
LLAMA_LOG_WARN("=====================================================================\n");
18575+
LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n");
18576+
LLAMA_LOG_WARN("=====================================================================\n");
18577+
cparams.mla_attn = 0;
18578+
}
18579+
1857118580
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
1857218581
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
1857318582
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);

0 commit comments

Comments (0)