
Commit 66a6707

only use embd output for pooling_type NONE

1 parent 7471694 commit 66a6707

File tree

1 file changed: +3 -3 lines changed

llama.cpp

Lines changed: 3 additions & 3 deletions
@@ -10861,7 +10861,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }

-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;

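In words, the new condition fills inp_out_ids only when per-token outputs will actually be read back: either logits are wanted (!cparams.embeddings) or unpooled, token-level embeddings (LLAMA_POOLING_TYPE_NONE). A minimal truth-table sketch of just that predicate, using stand-in booleans rather than the real cparams struct:

#include <cstdio>

int main() {
    // stand-ins for cparams.embeddings and (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)
    for (int embeddings = 0; embeddings <= 1; ++embeddings) {
        for (int pooling_none = 0; pooling_none <= 1; ++pooling_none) {
            // condition introduced by this commit: skip inp_out_ids only for pooled embeddings
            const bool set_out_ids = !embeddings || pooling_none;
            printf("embeddings=%d pooling_none=%d -> set inp_out_ids: %s\n",
                   embeddings, pooling_none, set_out_ids ? "yes" : "no");
        }
    }
    return 0;
}

The only skipped case is embeddings with a pooled type, which matches the commit message: the per-token output path is kept for LLAMA_POOLING_TYPE_NONE.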
@@ -10893,7 +10893,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // (!a || b) is a logical implication (a -> b)
             // !hparams.causal_attn -> !cparams.causal_attn
             (hparams.causal_attn || !cparams.causal_attn) &&
-            "causal attention with embedding models is not supported"
+            "causal attention is not supported by this model"
         );

         if (lctx.inp_KQ_mask) {
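Only the assert message changes here, but the surrounding comment carries the actual reasoning: with a = !hparams.causal_attn and b = !cparams.causal_attn, the expression (hparams.causal_attn || !cparams.causal_attn) is (!a || b), i.e. the implication a -> b, so a model without causal attention implies the context must not request it either. A tiny self-contained check of that equivalence (illustrative only, not part of llama.cpp):

#include <cassert>

int main() {
    // verify (!a || b) == (a -> b) over all truth assignments;
    // a -> b is false only when a is true and b is false
    for (int a = 0; a <= 1; ++a) {
        for (int b = 0; b <= 1; ++b) {
            const bool implication = !(a && !b);
            assert((!a || b) == implication);
        }
    }
    return 0;
}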
@@ -11118,7 +11118,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {

     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = !cparams.embeddings;
-    const bool has_embd = cparams.embeddings;
+    const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
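After this change the per-token embd buffer in llama_output_reserve is only sized when embeddings are requested and pooling is NONE; for pooled types nothing is reserved here, in line with the commit message. A hedged sketch of the sizing arithmetic from the hunk, with made-up n_vocab, n_embd and n_outputs_max values standing in for the real model and batch parameters:

#include <cstdio>

int main() {
    // illustrative values only; the real ones come from the model hparams and the batch
    const size_t n_vocab       = 32000;
    const size_t n_embd        = 4096;
    const size_t n_outputs_max = 512;

    const bool embeddings   = true;   // stand-in for cparams.embeddings
    const bool pooling_none = false;  // stand-in for cparams.pooling_type == LLAMA_POOLING_TYPE_NONE

    const bool has_logits = !embeddings;
    const bool has_embd   =  embeddings && pooling_none;  // condition after this commit

    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
    const size_t embd_size   = has_embd   ? n_embd*n_outputs_max  : 0;

    // embeddings requested with a pooled type: neither per-token buffer is reserved here
    printf("logits_size=%zu embd_size=%zu\n", logits_size, embd_size);
    return 0;
}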
