Commit 5582c49

gemma : more consistent attention scaling for v2 and v3 (#13951)
* gemma : fix attn scale for 27B
* cont : apply scale before attn
* cont : consistent attention scaling
1 parent c9bbc77 commit 5582c49
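
The commit consolidates the attention scaling for Gemma 2 and Gemma 3: instead of branching on the model type at graph-build time, the scale is precomputed once in load_hparams as hparams.f_attention_scale and applied to the query tensor before build_attn. Following the referenced gemma_pytorch config, the 27B variant scales by 1/sqrt(n_embd / n_head) rather than the usual 1/sqrt(n_embd_head_k). The standalone sketch below is illustrative only, not part of the commit; the 27B hyperparameters (n_embd = 4608, n_head = 32, n_embd_head_k = 128) are assumed from that config.

// Illustrative sketch, not llama.cpp code. Assumed Gemma 2 27B hyperparameters:
// n_embd = 4608, n_head = 32, n_embd_head_k = 128 (per the referenced gemma_pytorch config).
#include <cmath>
#include <cstdio>

int main() {
    const int n_embd        = 4608; // assumed hidden size of the 27B model
    const int n_head        = 32;   // assumed number of attention heads
    const int n_embd_head_k = 128;  // assumed per-head key/query dimension

    // 27B rule: 1/sqrt(n_embd / n_head) = 1/sqrt(144)
    const float scale_27b   = 1.0f / std::sqrt(float(n_embd / n_head));
    // all other sizes: the usual 1/sqrt(head_dim) = 1/sqrt(128)
    const float scale_other = 1.0f / std::sqrt(float(n_embd_head_k));

    std::printf("27B scale  : %.6f\n", scale_27b);   // ~0.083333
    std::printf("other scale: %.6f\n", scale_other); // ~0.088388
    return 0;
}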

File tree

1 file changed

+11 -9 lines changed
src/llama-model.cpp

Lines changed: 11 additions & 9 deletions
@@ -956,6 +956,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -976,6 +981,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -8484,14 +8490,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8632,9 +8631,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }

             cur = build_norm(cur,