
Commit 7471694

find result_norm/result_embd tensors properly; update output allocation logic
1 parent 1b09d09 · commit 7471694

File tree

3 files changed: +17 -11 lines changed


examples/embedding/embedding.cpp

Lines changed: 2 additions & 2 deletions

@@ -31,8 +31,8 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);

examples/retrieval/retrieval.cpp

Lines changed: 3 additions & 3 deletions

@@ -147,9 +147,9 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
-    for (size_t i = 0; i < tokens.size(); i++) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
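
Both examples change batch_add_seq the same way: the sequence id parameter becomes llama_seq_id (the element type llama_batch_add takes for sequence ids) and the cached token count becomes size_t, matching the size_t loop index. Below is a minimal standalone sketch of the pattern, not llama.cpp code: the pooling enum, the needs_output rule, and the batch_entry struct are illustrative stand-ins for llama_pooling_type, needs_logit, and llama_batch. It shows what the helper does: append every token of one input under a single sequence id and flag for output only the positions the pooling type needs.

// Standalone sketch (mock types, not llama.cpp): models the batch_add_seq pattern
// after this commit. The pooling enum and selection rule below are assumptions
// mirroring the example helpers, not the library's definitions.
#include <cstdint>
#include <cstdio>
#include <vector>

enum class pooling { none, mean, cls, last }; // simplified stand-in for llama_pooling_type

using seq_id_t = int32_t;                     // stand-in for llama_seq_id

struct batch_entry {
    int32_t  token;
    int32_t  pos;
    seq_id_t seq_id;
    bool     output;  // corresponds to the `logit` flag passed to llama_batch_add
};

static bool needs_output(pooling p, size_t pos, size_t n_tokens) {
    switch (p) {
        case pooling::none:
        case pooling::mean: return true;                // every position contributes
        case pooling::cls:  return pos == 0;            // only the first token
        case pooling::last: return pos == n_tokens - 1; // only the final token
    }
    return false;
}

static void batch_add_seq(std::vector<batch_entry> & batch,
                          const std::vector<int32_t> & tokens,
                          seq_id_t seq_id, pooling p) {
    size_t n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        batch.push_back({ tokens[i], (int32_t) i, seq_id, needs_output(p, i, n_tokens) });
    }
}

int main() {
    std::vector<batch_entry> batch;
    batch_add_seq(batch, { 101, 2023, 2003, 102 }, /*seq_id=*/0, pooling::last);
    for (const auto & e : batch) {
        std::printf("pos %d seq %d output=%d\n", (int) e.pos, (int) e.seq_id, (int) e.output);
    }
    return 0;
}

Run against a four-token input with last-token pooling, only position 3 is flagged, which is why the embedding and retrieval examples can read back one pooled vector per sequence id.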

llama.cpp

Lines changed: 12 additions & 6 deletions

@@ -7042,11 +7042,17 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
-        struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
-        if (strcmp(inp->name, "result_embd") != 0) {
-            inp = gf->nodes[gf->n_nodes - 2];
-            GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
         }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
 
         struct ggml_tensor * cur;
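
The old lookup only checked the last graph node for result_embd and the second-to-last for result_norm; the new code walks the node list from the back and accepts whichever of the two names it finds first, asserting if neither is present. A minimal standalone sketch of that search pattern, using a mock node type in place of ggml_cgraph and ggml_tensor:

// Standalone sketch (mock types, not ggml): the lookup pattern the new append_pooling
// uses. Walk the node list backwards until a node named "result_norm" or "result_embd"
// turns up, and fail loudly if neither exists.
#include <cassert>
#include <cstring>
#include <vector>

struct node { const char * name; };   // stand-in for ggml_tensor, which carries a name

static node * find_embd_input(std::vector<node> & nodes) {
    for (int i = (int) nodes.size() - 1; i >= 0; --i) {
        if (std::strcmp(nodes[i].name, "result_norm") == 0 ||
            std::strcmp(nodes[i].name, "result_embd") == 0) {
            return &nodes[i];
        }
    }
    return nullptr;
}

int main() {
    // some graphs may place the target tensor more than two nodes from the end,
    // which is where a purely positional lookup becomes fragile
    std::vector<node> graph = { { "inp_tokens" }, { "result_norm" }, { "result_output" } };
    node * inp = find_embd_input(graph);
    assert(inp != nullptr && "missing result_norm/result_embd tensor");
    return 0;
}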

@@ -11111,8 +11117,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   =  cparams.embeddings;
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
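
Output allocation now keys off cparams.embeddings alone: when embeddings are requested only the embedding buffer is reserved, otherwise only the logits buffer, rather than deriving the decision from causal_attn and the pooling type. A standalone sketch of the new sizing rule (illustrative values; not the actual llama_output_reserve):

// Standalone sketch of the reservation rule after this commit: the two output buffers
// are mutually exclusive. The sizes and values below are illustrative only.
#include <cstddef>
#include <cstdio>

struct output_sizes {
    size_t logits_size;  // floats reserved for token logits
    size_t embd_size;    // floats reserved for embeddings
};

static output_sizes reserve_outputs(bool embeddings, size_t n_vocab, size_t n_embd,
                                    size_t n_outputs_max) {
    const bool has_logits = !embeddings;
    const bool has_embd   =  embeddings;
    return {
        has_logits ? n_vocab * n_outputs_max : 0,
        has_embd   ? n_embd  * n_outputs_max : 0,
    };
}

int main() {
    // e.g. a 32000-token vocab, 4096-dim embeddings, room for 8 outputs
    output_sizes s = reserve_outputs(/*embeddings=*/true, 32000, 4096, 8);
    std::printf("logits: %zu floats, embeddings: %zu floats\n", s.logits_size, s.embd_size);
    return 0;
}

With embeddings enabled the sketch reserves 4096 * 8 floats for embeddings and nothing for logits; with embeddings disabled it is the other way around.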
