Commit 1885859

compilade authored and NeoZhangJianyu committed
llama : fix quantization of shared token_embd (ggml-org#5944)
1 parent 6932e1a commit 1885859

File tree: 1 file changed (+7, -2 lines)

llama.cpp

Lines changed: 7 additions & 2 deletions
@@ -10983,6 +10983,9 @@ struct quantize_state_internal {
 
     bool has_imatrix = false;
 
+    // used to figure out if a model shares tok_embd with the output weight
+    bool has_output = false;
+
     quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
         : model(model)
         , params(params)
@@ -11080,8 +11083,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
 
     // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
     // with the quantization of the output tensor
-    if (name == tn(LLM_TENSOR_OUTPUT, "weight") ||
-        (LLM_TENSOR_NAMES.at(arch).find(LLM_TENSOR_OUTPUT) == LLM_TENSOR_NAMES.at(arch).end() && name == "token_embd.weight")) {
+    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
         int nx = tensor->ne[0];
         if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
             new_type = GGML_TYPE_Q8_0;
@@ -11470,6 +11472,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         else if (name.find("ffn_up") != std::string::npos) {
             ++qs.n_ffn_up;
         }
+        else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
+            qs.has_output = true;
+        }
     }
     if (qs.n_attention_wv != qs.n_ffn_down || (uint32_t)qs.n_attention_wv != model.hparams.n_layer) {
         LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_ffn_down = %d, hparams.n_layer = %d\n",
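
In effect, the old condition asked whether the architecture's static LLM_TENSOR_NAMES table defines an output tensor at all, while the new qs.has_output flag records whether an output.weight tensor is actually present in the model being quantized; token_embd.weight is quantized like the output tensor only when no separate output tensor exists. The following is a minimal standalone sketch of that two-pass pattern, not the actual llama.cpp code: the helper wants_output_quant and the sample tensor lists are hypothetical, and only the tensor names token_embd.weight and output.weight are taken from the naming seen in the diff.

// Standalone sketch of the shared token_embd/output detection (hypothetical code,
// not from llama.cpp): first scan the tensor names to see whether a dedicated
// output tensor exists, then decide per tensor how it should be quantized.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// True when this tensor should receive the output-style quantization: either it
// is the output tensor itself, or it is token_embd.weight in a model that has no
// separate output tensor and therefore reuses the embeddings as the output.
static bool wants_output_quant(const std::string & name, bool has_output) {
    return name == "output.weight" ||
           (!has_output && name == "token_embd.weight");
}

int main() {
    // Two hypothetical models: one with a dedicated output tensor, one with
    // tied embeddings (no output.weight, so token_embd.weight is shared).
    const std::vector<std::vector<std::string>> models = {
        { "token_embd.weight", "blk.0.attn_q.weight", "output.weight" },
        { "token_embd.weight", "blk.0.attn_q.weight" },
    };

    for (const auto & tensors : models) {
        // First pass: mirrors the new qs.has_output flag.
        const bool has_output = std::any_of(tensors.begin(), tensors.end(),
            [](const std::string & n) { return n == "output.weight"; });

        // Second pass: mirrors the updated check in get_k_quant_type().
        for (const auto & name : tensors) {
            std::cout << name << " -> "
                      << (wants_output_quant(name, has_output) ? "output-style quant"
                                                               : "regular quant")
                      << "\n";
        }
        std::cout << "---\n";
    }
    return 0;
}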

0 commit comments
