llama : support qwen3 rerank and embeddings #14029
src/llama-graph.cpp
@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+    if (!cparams.embeddings) {
+        return;
+    }
+
+    const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+        arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token
+
+    if (is_last_tok) {
+        // set output to the last token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-    }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+
+    } else {
+        // set output to first token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
-
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
-        }
     }
 }
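To make the selection rule above easier to follow in isolation, here is a minimal standalone sketch of the last-token bookkeeping; the flat pos/seq_id vectors are an assumption for illustration, not the ubatch layout from this PR:

```cpp
#include <cstdint>
#include <vector>

// For each sequence id, remember the row of the token with the highest
// position, so out_row[seq] ends up pointing at that sequence's last token
// (mirrors the last_pos/last_row loop in the diff above).
static void select_last_rows(const std::vector<int32_t> & pos,
                             const std::vector<int32_t> & seq_id,
                             std::vector<uint32_t>      & out_row) {
    std::vector<int32_t> last_pos(out_row.size(), -1);
    for (size_t i = 0; i < pos.size(); ++i) {
        const int32_t s = seq_id[i];
        if (pos[i] >= last_pos[s]) {
            last_pos[s] = pos[i];
            out_row[s]  = (uint32_t) i;
        }
    }
}
```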
@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 
@@ -1577,10 +1583,15 @@ void llm_graph_context::build_pooling(
                     cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
                 }
             } else if (cls_out) {
-                // Single layer classification head (direct projection)
-                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                GGML_ASSERT(cls_out_b != nullptr);
-                cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                    cur = ggml_soft_max(ctx0, cur); // qwen3 uses softmax on the output
@ggerganov I think there is a bug with [...]

I think you can make a quick fix for now like this, similar to [...]:

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index afef84870..8b11197df 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7043,7 +7043,7 @@ struct llm_build_qwen3 : public llm_graph_context {
                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
 
-        if (il == n_layer - 1) {
+        if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
             // skip computing output for unused tokens
             ggml_tensor * inp_out_ids = build_inp_out_ids();
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);

No, that still doesn't work as I expected. For example, if my sequence has only one output token, then I expect the [...]. Maybe I misunderstood something here?

Hmm, OK, I think I got it. The main problem is that Qwen's rerank model uses causal attention; it's simply a normal next-token generation model that outputs either a "yes" or a "no" token. I think the assumption in llama.cpp is that CLS and RANK are non-causal, hence only the first token is marked as output. Not sure what's the best way to support this, though.

OK, found a hack around this: for Qwen3, I force the position to last (only the position, not the pooling) in 030dc3b. Probably we should separate the notions of "pooling" and "output position" in the future.

The idea is that the [...]. For Qwen3 rerank, what you seem to need is to pool using [...]. And it seems we should remove [...].

Hmm, OK, I got it. The problem is that [...]. I don't have much time for the rest of the day. Do you think we can clean this up in a follow-up PR? I think having the notion of [...]

Think it's better to take the time and make it right, no need to merge it now.
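For readers following the thread: the rerank score ultimately comes from a softmax over the model's "yes"/"no" choices at the last token. A standalone sketch of that idea; the two-logit setup is an assumption about the Qwen3 rerank recipe, not code from this PR:

```cpp
#include <algorithm>
#include <cmath>

// Relevance score as the probability of "yes" under a softmax over the
// "yes"/"no" logits of the last token (numerically stable form).
static float rerank_score(float logit_yes, float logit_no) {
    const float m     = std::max(logit_yes, logit_no);
    const float e_yes = std::exp(logit_yes - m);
    const float e_no  = std::exp(logit_no  - m);
    return e_yes / (e_yes + e_no);
}
```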
+                } else {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                    GGML_ASSERT(cls_out_b != nullptr);
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                }
             } else {
                 GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
             }
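On the caller side, the pooled RANK output is exposed per sequence. A usage sketch, assuming the standard llama.h embeddings API rather than anything added by this PR:

```cpp
#include <cstdio>
#include "llama.h"

// Assumes ctx was created with embeddings enabled and RANK pooling, and
// that a query+document pair has already been decoded into sequence 0.
static void print_rerank_score(llama_context * ctx) {
    const float * score = llama_get_embeddings_seq(ctx, 0);
    if (score != nullptr) {
        printf("rerank score: %f\n", score[0]); // one value per sequence with RANK pooling
    }
}
```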
src/llama-model.cpp
@@ -825,6 +825,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_QWEN3:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
                 switch (hparams.n_layer) {
                     case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
                     case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break;
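The hunk above lets a Qwen3 GGUF carry a default pooling type in its metadata. Callers can still choose pooling at context creation; a hypothetical helper, assuming the standard llama.h context API (not part of this PR):

```cpp
#include "llama.h"

// Build a context configured for reranking; the GGUF metadata read above
// only provides a default, the caller-side override below is illustrative.
static llama_context * make_rerank_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // request pooled outputs
    cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; // rerank-style pooling
    return llama_init_from_model(model, cparams);
}
```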
@@ -2468,6 +2470,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
+                    // output rerank
+                    cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
@@ -7057,7 +7062,7 @@ struct llm_build_qwen3 : public llm_graph_context {
                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
-            if (il == n_layer - 1) {
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
                 // skip computing output for unused tokens
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -13788,7 +13793,8 @@ uint64_t llama_model_size(const llama_model * model) {
 }
 
 const char * llama_model_chat_template(const llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+    const auto key = name
+        ? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
Comment on lines +13796 to +13797:

I wonder how long this has been broken?

Ah, it has never worked, it seems: broken since it was introduced in #11016.

It's not used by any of the examples, so we don't know if it ever worked in the first place (probably used in downstream projects, but idk).

I'll make a PR.
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
Actually, you can use LLM_KV_TOKENIZER_CHAT_TEMPLATE with a suffix:

llama.cpp/src/llama-arch.cpp, lines 1722 to 1725 in c02f53d

Tried doing this, but it doesn't work; maybe it's buggy somewhere else: [...]

I was looking in the wrong place, this is where it's broken:

llama.cpp/src/llama-arch.cpp, lines 1709 to 1712 in e83ba3e

Fixed in #14050
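For reference, a usage sketch of the named-template lookup discussed in this thread; the "tools" name is just an illustrative placeholder, and the suffixed key form follows the LLM_KV_TOKENIZER_CHAT_TEMPLATE_N convention mentioned above:

```cpp
#include "llama.h"

// Look up a named chat template, falling back to the default one.
static const char * pick_chat_template(const llama_model * model) {
    const char * tmpl = llama_model_chat_template(model, "tools"); // e.g. "tokenizer.chat_template.tools"
    if (tmpl == nullptr) {
        tmpl = llama_model_chat_template(model, /*name=*/nullptr); // "tokenizer.chat_template"
    }
    return tmpl;
}
```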