Skip to content

Commit faba70f

Browse files
iamlemecNeoZhangJianyu
authored andcommitted
llama : allow pooled embeddings on any model (ggml-org#7477)
* create append_pooling operation; allow to specify attention_type; add last token pooling; update examples * find result_norm/result_embd tensors properly; update output allocation logic * only use embd output for pooling_type NONE * get rid of old causal_attn accessor * take out attention_type; add in llama_set_embeddings * bypass logits when doing non-NONE pooling
1 parent 2554c4f commit faba70f

File tree

6 files changed

+130
-70
lines changed

6 files changed

+130
-70
lines changed

common/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
541541
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
542542
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
543543
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
544+
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
544545
else { invalid_param = true; }
545546
return true;
546547
}
@@ -1869,6 +1870,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
18691870

18701871
options.push_back({ "backend" });
18711872
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
1873+
18721874
if (llama_supports_mlock()) {
18731875
options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
18741876
}

examples/embedding/embedding.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
1717
return lines;
1818
}
1919

20-
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
21-
for (size_t i = 0; i < tokens.size(); i++) {
22-
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
20+
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
21+
size_t n_tokens = tokens.size();
22+
for (size_t i = 0; i < n_tokens; i++) {
23+
llama_batch_add(batch, tokens[i], i, { seq_id }, true);
2324
}
2425
}
2526

@@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
4041

4142
// try to get sequence embeddings - supported only when pooling_type is not NONE
4243
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
43-
if (embd == NULL) {
44-
embd = llama_get_embeddings_ith(ctx, i);
45-
if (embd == NULL) {
46-
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
47-
continue;
48-
}
49-
}
44+
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
5045

5146
float * out = output + batch.seq_id[i][0] * n_embd;
5247
//TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +92,12 @@ int main(int argc, char ** argv) {
9792
const int n_ctx_train = llama_n_ctx_train(model);
9893
const int n_ctx = llama_n_ctx(ctx);
9994

95+
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
96+
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
97+
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
98+
return 1;
99+
}
100+
100101
if (n_ctx > n_ctx_train) {
101102
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
102103
__func__, n_ctx_train, n_ctx);

examples/gritlm/gritlm.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4444

4545
// clear previous kv_cache values (irrelevant for embeddings)
4646
llama_kv_cache_clear(ctx);
47+
llama_set_embeddings(ctx, true);
4748
llama_set_causal_attn(ctx, false);
4849

4950
// run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
9899
llama_token eos_token = llama_token_eos(mdl);
99100

100101
llama_kv_cache_clear(ctx);
102+
llama_set_embeddings(ctx, false);
101103
llama_set_causal_attn(ctx, true);
104+
102105
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
103106

104107
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {
166169

167170
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
168171

169-
// create new context - set to embedding mode
170-
cparams.embeddings = true;
172+
// create generation context
171173
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
172174

173175
// ### Embedding/Representation ###

examples/retrieval/retrieval.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
7373
return chunks;
7474
}
7575

76-
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
77-
for (size_t i = 0; i < tokens.size(); i++) {
78-
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
76+
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
77+
size_t n_tokens = tokens.size();
78+
for (size_t i = 0; i < n_tokens; i++) {
79+
llama_batch_add(batch, tokens[i], i, { seq_id }, true);
7980
}
8081
}
8182

@@ -160,6 +161,12 @@ int main(int argc, char ** argv) {
160161
const int n_ctx_train = llama_n_ctx_train(model);
161162
const int n_ctx = llama_n_ctx(ctx);
162163

164+
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
165+
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
166+
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
167+
return 1;
168+
}
169+
163170
if (n_ctx > n_ctx_train) {
164171
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
165172
__func__, n_ctx_train, n_ctx);

llama.cpp

Lines changed: 98 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7650,6 +7650,50 @@ struct llm_build_context {
76507650
return lctx.inp_s_seq;
76517651
}
76527652

7653+
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
7654+
// find result_norm tensor for input
7655+
struct ggml_tensor * inp = nullptr;
7656+
for (int i = gf->n_nodes - 1; i >= 0; --i) {
7657+
inp = gf->nodes[i];
7658+
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
7659+
break;
7660+
} else {
7661+
inp = nullptr;
7662+
}
7663+
}
7664+
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
7665+
7666+
struct ggml_tensor * cur;
7667+
7668+
switch (pooling_type) {
7669+
case LLAMA_POOLING_TYPE_MEAN:
7670+
{
7671+
struct ggml_tensor * inp_mean = build_inp_mean();
7672+
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
7673+
} break;
7674+
case LLAMA_POOLING_TYPE_CLS:
7675+
case LLAMA_POOLING_TYPE_LAST:
7676+
{
7677+
struct ggml_tensor * inp_cls = build_inp_cls();
7678+
cur = ggml_get_rows(ctx0, inp, inp_cls);
7679+
} break;
7680+
case LLAMA_POOLING_TYPE_NONE:
7681+
{
7682+
cur = inp;
7683+
} break;
7684+
default:
7685+
{
7686+
GGML_ASSERT(false && "unknown pooling type");
7687+
} break;
7688+
}
7689+
7690+
cb(cur, "result_embd_pooled", -1);
7691+
7692+
ggml_build_forward_expand(gf, cur);
7693+
7694+
return gf;
7695+
}
7696+
76537697
struct ggml_cgraph * build_llama() {
76547698
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
76557699

@@ -8630,8 +8674,6 @@ struct llm_build_context {
86308674
if (model.arch != LLM_ARCH_JINA_BERT_V2) {
86318675
inp_pos = build_inp_pos();
86328676
}
8633-
struct ggml_tensor * inp_mean = build_inp_mean();
8634-
struct ggml_tensor * inp_cls = build_inp_cls();
86358677

86368678
// construct input embeddings (token, type, position)
86378679
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8806,28 +8848,6 @@ struct llm_build_context {
88068848
cur = inpL;
88078849
cb(cur, "result_embd", -1);
88088850

8809-
// pooling layer
8810-
switch (pooling_type) {
8811-
case LLAMA_POOLING_TYPE_NONE:
8812-
{
8813-
// nop
8814-
} break;
8815-
case LLAMA_POOLING_TYPE_MEAN:
8816-
{
8817-
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
8818-
cb(cur, "result_embd_pooled", -1);
8819-
} break;
8820-
case LLAMA_POOLING_TYPE_CLS:
8821-
{
8822-
cur = ggml_get_rows(ctx0, cur, inp_cls);
8823-
cb(cur, "result_embd_pooled", -1);
8824-
} break;
8825-
case LLAMA_POOLING_TYPE_UNSPECIFIED:
8826-
{
8827-
GGML_ASSERT(false && "Invalid pooling type");
8828-
} break;
8829-
}
8830-
88318851
ggml_build_forward_expand(gf, cur);
88328852

88338853
return gf;
@@ -11912,6 +11932,11 @@ static struct ggml_cgraph * llama_build_graph(
1191211932
GGML_ASSERT(false);
1191311933
}
1191411934

11935+
// add on pooling layer
11936+
if (lctx.cparams.embeddings) {
11937+
result = llm.append_pooling(result);
11938+
}
11939+
1191511940
llm.free();
1191611941

1191711942
return result;
@@ -12001,7 +12026,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1200112026
// (!a || b) is a logical implication (a -> b)
1200212027
// !hparams.causal_attn -> !cparams.causal_attn
1200312028
(hparams.causal_attn || !cparams.causal_attn) &&
12004-
"causal attention with embedding models is not supported"
12029+
"causal attention is not supported by this model"
1200512030
);
1200612031

1200712032
if (lctx.inp_KQ_mask) {
@@ -12133,6 +12158,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1213312158
}
1213412159
}
1213512160

12161+
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
12162+
const int64_t n_tokens = batch.n_tokens;
12163+
12164+
GGML_ASSERT(lctx.inp_cls);
12165+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
12166+
12167+
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
12168+
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
12169+
12170+
std::vector<int> last_pos(n_tokens, -1);
12171+
std::vector<int> last_row(n_tokens, -1);
12172+
12173+
for (int i = 0; i < n_tokens; ++i) {
12174+
const llama_seq_id seq_id = batch.seq_id[i][0];
12175+
const llama_pos pos = batch.pos[i];
12176+
12177+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
12178+
12179+
if (pos >= last_pos[seq_id]) {
12180+
last_pos[seq_id] = pos;
12181+
last_row[seq_id] = i;
12182+
}
12183+
}
12184+
12185+
for (int i = 0; i < n_tokens; ++i) {
12186+
if (last_row[i] >= 0) {
12187+
data[i] = last_row[i];
12188+
}
12189+
}
12190+
}
12191+
1213612192
if (kv_self.recurrent) {
1213712193
const int64_t n_kv = kv_self.n;
1213812194

@@ -12194,8 +12250,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
1219412250
const auto n_embd = hparams.n_embd;
1219512251

1219612252
// TODO: use a per-batch flag for logits presence instead
12197-
const bool has_logits = cparams.causal_attn;
12198-
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
12253+
const bool has_logits = !cparams.embeddings;
12254+
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1219912255

1220012256
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1220112257
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -12325,11 +12381,13 @@ static int llama_decode_internal(
1232512381
std::vector<std::vector<llama_seq_id>> seq_id;
1232612382

1232712383
// count outputs
12328-
if (batch_all.logits) {
12384+
if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
12385+
n_outputs = n_tokens_all;
12386+
} else if (batch_all.logits) {
1232912387
for (uint32_t i = 0; i < n_tokens_all; ++i) {
1233012388
n_outputs += batch_all.logits[i] != 0;
1233112389
}
12332-
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
12390+
} else if (lctx.logits_all) {
1233312391
n_outputs = n_tokens_all;
1233412392
} else {
1233512393
// keep last output only
@@ -12460,30 +12518,13 @@ static int llama_decode_internal(
1246012518
// no output
1246112519
res = nullptr;
1246212520
embd = nullptr;
12463-
} else if (!hparams.causal_attn) {
12464-
res = nullptr; // do not extract logits for embedding models such as BERT
12465-
12466-
// token or sequence embeddings
12467-
embd = gf->nodes[gf->n_nodes - 1];
12468-
12469-
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
1247012521
} else if (cparams.embeddings) {
12471-
// the embeddings could be in the second to last tensor, or any of the previous tensors
12472-
int i_embd = gf->n_nodes - 2;
12473-
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
12474-
i_embd = gf->n_nodes - i;
12475-
if (i_embd < 0) { break; }
12476-
embd = gf->nodes[i_embd];
12477-
}
12478-
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
12479-
12480-
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
12481-
if (!cparams.causal_attn) {
12482-
res = nullptr; // do not extract logits when not needed
12483-
// skip computing logits
12484-
// TODO: is this safe?
12485-
gf->n_nodes = i_embd + 1;
12522+
res = nullptr; // do not extract logits for embedding case
12523+
embd = gf->nodes[gf->n_nodes - 1];
12524+
if (strcmp(embd->name, "result_embd_pooled") != 0) {
12525+
embd = gf->nodes[gf->n_nodes - 2];
1248612526
}
12527+
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
1248712528
} else {
1248812529
embd = nullptr; // do not extract embeddings when not needed
1248912530
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@@ -12552,11 +12593,10 @@ static int llama_decode_internal(
1255212593
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
1255312594
}
1255412595
} break;
12555-
case LLAMA_POOLING_TYPE_CLS:
1255612596
case LLAMA_POOLING_TYPE_MEAN:
12597+
case LLAMA_POOLING_TYPE_CLS:
12598+
case LLAMA_POOLING_TYPE_LAST:
1255712599
{
12558-
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
12559-
1256012600
// extract sequence embeddings
1256112601
auto & embd_seq_out = lctx.embd_seq;
1256212602
embd_seq_out.clear();
@@ -18112,6 +18152,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
1811218152
ctx->abort_callback_data = abort_callback_data;
1811318153
}
1811418154

18155+
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
18156+
ctx->cparams.embeddings = embeddings;
18157+
}
18158+
1811518159
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
1811618160
ctx->cparams.causal_attn = causal_attn;
1811718161
}

llama.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ extern "C" {
174174
LLAMA_POOLING_TYPE_NONE = 0,
175175
LLAMA_POOLING_TYPE_MEAN = 1,
176176
LLAMA_POOLING_TYPE_CLS = 2,
177+
LLAMA_POOLING_TYPE_LAST = 3,
177178
};
178179

179180
enum llama_split_mode {
@@ -293,7 +294,6 @@ extern "C" {
293294

294295
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
295296
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
296-
// (ignored if no pooling layer)
297297

298298
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
299299
float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -786,6 +786,10 @@ extern "C" {
786786
// Get the number of threads used for prompt and batch processing (multiple token).
787787
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
788788

789+
// Set whether the model is in embeddings model or not
790+
// If true, embeddings will be returned but logits will not
791+
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
792+
789793
// Set whether to use causal attention or not
790794
// If set to true, the model will only attend to the past tokens
791795
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

0 commit comments

Comments
 (0)