@@ -956,6 +956,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
956
956
case 46: type = LLM_TYPE_27B; break;
957
957
default: type = LLM_TYPE_UNKNOWN;
958
958
}
959
+
960
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
961
+ hparams.f_attention_scale = type == LLM_TYPE_27B
962
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
963
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
959
964
} break;
960
965
case LLM_ARCH_GEMMA3:
961
966
{
@@ -976,6 +981,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
976
981
default: type = LLM_TYPE_UNKNOWN;
977
982
}
978
983
984
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
979
985
hparams.f_attention_scale = type == LLM_TYPE_27B
980
986
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
981
987
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -8484,14 +8490,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8484
8490
cb(Kcur, "Kcur", il);
8485
8491
cb(Vcur, "Vcur", il);
8486
8492
8487
- // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
8488
- switch (model.type) {
8489
- case LLM_TYPE_2B:
8490
- case LLM_TYPE_9B:
8491
- case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
8492
- default: GGML_ABORT("fatal error");
8493
- };
8494
- cb(Qcur, "Qcur_scaled", il);
8493
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8495
8494
8496
8495
cur = build_attn(inp_attn, gf,
8497
8496
model.layers[il].wo, NULL,
@@ -8632,9 +8631,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8632
8631
cb(Kcur, "Kcur", il);
8633
8632
cb(Vcur, "Vcur", il);
8634
8633
8634
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
8635
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8636
+
8635
8637
cur = build_attn(inp_attn, gf,
8636
8638
model.layers[il].wo, NULL,
8637
- Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale , il);
8639
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f , il);
8638
8640
}
8639
8641
8640
8642
cur = build_norm(cur,
0 commit comments