@@ -10861,7 +10861,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
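Note on the hunk above: the gate for filling inp_out_ids now keys off whether the caller requested embeddings, rather than off the model's own causal_attn flag. A minimal sketch of the before/after predicates, using simplified stand-ins for the hparams/cparams fields (the struct names and the needs_out_ids_* helpers are illustrative, not llama.cpp API):

#include <cstdio>

// Simplified stand-ins for the relevant llama.cpp fields (illustrative only).
enum pooling_type { POOLING_NONE, POOLING_MEAN, POOLING_CLS };

struct hparams_t { bool causal_attn; };                        // model property
struct cparams_t { bool embeddings; pooling_type pooling; };   // per-context request

// Old gate: the output skip-list is built when the model is causal or outputs are per-token.
static bool needs_out_ids_old(const hparams_t & h, const cparams_t & c) {
    return h.causal_attn || c.pooling == POOLING_NONE;
}

// New gate: the skip-list is built whenever outputs are not pooled,
// i.e. logits mode, or embeddings returned per token.
static bool needs_out_ids_new(const hparams_t & /*h*/, const cparams_t & c) {
    return !c.embeddings || c.pooling == POOLING_NONE;
}

int main() {
    // A case where the two gates differ: a causal model asked for pooled embeddings.
    // The old code would still fill inp_out_ids; the new code skips it, since
    // pooling consumes every token's output anyway.
    const hparams_t h{ /*causal_attn=*/true };
    const cparams_t c{ /*embeddings=*/true, POOLING_MEAN };
    printf("old gate: %d, new gate: %d\n", needs_out_ids_old(h, c), needs_out_ids_new(h, c));
    return 0;
}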
@@ -10893,7 +10893,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // (!a || b) is a logical implication (a -> b)
         // !hparams.causal_attn -> !cparams.causal_attn
         (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention with embedding models is not supported"
+        "causal attention is not supported by this model "
     );
 
     if (lctx.inp_KQ_mask) {
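The comment pair in this hunk spells out the encoding (a -> b) == (!a || b), so the assertion reads !hparams.causal_attn -> !cparams.causal_attn: a context may not request causal attention from a model that is not causal. A small self-contained check of that identity (the helper name and example flags below are placeholders, not llama.cpp code):

#include <cassert>

// Material implication: (a -> b) is equivalent to (!a || b).
constexpr bool implies(bool a, bool b) { return !a || b; }

// Truth table of (!a || b), verified at compile time.
static_assert( implies(false, false), "F -> F holds");
static_assert( implies(false, true ), "F -> T holds");
static_assert(!implies(true,  false), "T -> F fails");
static_assert( implies(true,  true ), "T -> T holds");

int main() {
    // The GGML_ASSERT above is this implication with
    // a = !hparams.causal_attn and b = !cparams.causal_attn.
    const bool model_is_causal   = false; // hparams.causal_attn (e.g. an encoder-only model)
    const bool context_is_causal = true;  // cparams.causal_attn (caller wants causal masking)

    const bool ok = implies(!model_is_causal, !context_is_causal);
    assert(ok == (model_is_causal || !context_is_causal)); // same expression as in the diff

    // ok is false here: a non-causal model with causal attention requested would trip the assert.
    return ok ? 0 : 1;
}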
@@ -11118,7 +11118,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  cparams.embeddings;
+    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ? n_embd*n_outputs_max  : 0;
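After this change the flat embedding buffer is reserved only for per-token (unpooled) embeddings; pooled output presumably goes through a separate path and no longer forces an n_embd*n_outputs_max allocation here. A rough sizing sketch with made-up dimensions (the helper and the numbers are illustrative, not llama.cpp code):

#include <cstddef>
#include <cstdio>

enum pooling_type { POOLING_NONE, POOLING_MEAN, POOLING_CLS };

// Illustrative mirror of the reservation logic in llama_output_reserve.
static size_t output_floats(bool embeddings, pooling_type pooling,
                            size_t n_vocab, size_t n_embd, size_t n_outputs_max) {
    const bool has_logits = !embeddings;
    const bool has_embd   =  embeddings && pooling == POOLING_NONE;

    const size_t logits_size = has_logits ? n_vocab * n_outputs_max : 0;
    const size_t embd_size   = has_embd   ? n_embd  * n_outputs_max : 0;

    return logits_size + embd_size;
}

int main() {
    // Hypothetical 7B-class dimensions, for scale only.
    const size_t n_vocab = 32000, n_embd = 4096, n_outputs_max = 512;

    // Text generation: only the logits buffer is reserved.
    printf("logits only : %zu floats\n", output_floats(false, POOLING_MEAN, n_vocab, n_embd, n_outputs_max));
    // Per-token embeddings: the embd buffer is sized for every output token.
    printf("embd (none) : %zu floats\n", output_floats(true,  POOLING_NONE, n_vocab, n_embd, n_outputs_max));
    // Pooled embeddings: neither buffer is reserved here after this change.
    printf("embd (mean) : %zu floats\n", output_floats(true,  POOLING_MEAN, n_vocab, n_embd, n_outputs_max));
    return 0;
}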