
Commit 0105714

create append_pooling operation; allow to specify attention_type; add last token pooling; update examples
1 parent f8ec887 commit 0105714

7 files changed: 175 additions, 74 deletions

common/common.cpp

Lines changed: 14 additions & 0 deletions
@@ -542,6 +542,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
         else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
         else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "--attention") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
         else { invalid_param = true; }
         return true;
     }
@@ -1820,6 +1832,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param

     options.push_back({ "backend" });
     options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
+
     if (llama_supports_mlock()) {
         options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
     }
@@ -2447,6 +2460,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx = params.yarn_orig_ctx;
     cparams.pooling_type = params.pooling_type;
+    cparams.attention_type = params.attention_type;
     cparams.defrag_thold = params.defrag_thold;
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
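
For reference, a minimal sketch (not part of this commit) of what the new options amount to when driving the C API directly instead of going through gpt_params. The model path is a placeholder; the context-param fields are the ones introduced or forwarded by the hunks above:

// sketch: set pooling/attention type on a context without the CLI layer
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings     = true;
    cparams.pooling_type   = LLAMA_POOLING_TYPE_LAST;     // what --pooling last sets
    cparams.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; // what --attention causal sets
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... build a batch, llama_decode(), read pooled embeddings ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}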

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ struct gpt_params {
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type

     // // sampling parameters
     struct llama_sampling_params sparams;

examples/embedding/embedding.cpp

Lines changed: 27 additions & 11 deletions
@@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }

-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }

@@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu

         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");

         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +107,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);

+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
@@ -176,7 +192,7 @@ int main(int argc, char ** argv) {
         }

         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }
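
To make the logit-selection rule concrete, here is a small illustrative sketch (not part of the commit) that mirrors the needs_logit helper added above and prints which positions of a 4-token sequence request output for each pooling type; the only deviation is that the default case returns false instead of asserting:

// sketch: which positions get logits=true in batch_add_seq, per pooling type
#include "llama.h"
#include <cstdio>

static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
    switch (pooling_type) {
        case LLAMA_POOLING_TYPE_MEAN:
        case LLAMA_POOLING_TYPE_NONE: return true;                // every position
        case LLAMA_POOLING_TYPE_CLS:  return pos == 0;            // first token only
        case LLAMA_POOLING_TYPE_LAST: return pos == n_tokens - 1; // last token only
        default:                      return false;               // the example asserts here instead
    }
}

int main() {
    const enum llama_pooling_type types[] = {
        LLAMA_POOLING_TYPE_MEAN, LLAMA_POOLING_TYPE_CLS, LLAMA_POOLING_TYPE_LAST,
    };
    const int n_tokens = 4;
    for (auto t : types) {
        printf("pooling %d:", (int) t);
        for (int pos = 0; pos < n_tokens; pos++) {
            printf(" %d", needs_logit(t, pos, n_tokens) ? 1 : 0);
        }
        printf("\n"); // MEAN -> 1 1 1 1, CLS -> 1 0 0 0, LAST -> 0 0 0 1
    }
    return 0;
}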

examples/gritlm/gritlm.cpp

Lines changed: 12 additions & 8 deletions
@@ -44,7 +44,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
-        llama_set_causal_attn(ctx, false);

         // run model
         llama_decode(ctx, batch);
@@ -98,7 +97,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);

     llama_kv_cache_clear(ctx);
-    llama_set_causal_attn(ctx, true);
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,9 +164,14 @@ int main(int argc, char * argv[]) {

     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

-    // create new context - set to embedding mode
+    // create generation context
+    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
+
+    // create embedding context
     cparams.embeddings = true;
-    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
+    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);

     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -186,8 +189,8 @@ int main(int argc, char * argv[]) {
     };

     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));

     const int n_embd = llama_n_embd(mdl);

@@ -206,10 +209,11 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx_gen, prompt, true);
     }

-    llama_free(ctx);
+    llama_free(ctx_gen);
+    llama_free(ctx_emb);
     llama_free_model(mdl);
     llama_backend_free();
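
The gritlm change above replaces the runtime llama_set_causal_attn() toggling with two dedicated contexts built from the same model. A condensed sketch of that pattern (assuming mdl is already loaded and cparams starts from llama_context_default_params(); the NONE pooling type suggests the example pools token embeddings itself):

// sketch: one causal context for generation, one non-causal context for embeddings
llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);  // default attention (model-defined)

cparams.embeddings     = true;
cparams.pooling_type   = LLAMA_POOLING_TYPE_NONE;
cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);

// encode() runs on ctx_emb, generate() runs on ctx_gen

llama_free(ctx_gen);
llama_free(ctx_emb);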

examples/retrieval/retrieval.cpp

Lines changed: 21 additions & 4 deletions
@@ -73,9 +73,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }

-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }

@@ -159,6 +175,7 @@ int main(int argc, char ** argv) {

     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -230,7 +247,7 @@ int main(int argc, char ** argv) {
         }

         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }

@@ -253,7 +270,7 @@ int main(int argc, char ** argv) {
         std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);

         struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-        batch_add_seq(query_batch, query_tokens, 0);
+        batch_add_seq(query_batch, query_tokens, 0, pooling_type);

         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
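
As a usage note, a hedged sketch (not part of the commit) of the query path the retrieval example now follows: the pooling type is read from the context and forwarded to batch_add_seq, and the pooled query embedding is fetched per sequence after decoding. It assumes model and ctx are the loaded model and context, batch_add_seq is the helper above, llama_tokenize/llama_batch_add are the common.h helpers, and the query string is a placeholder:

// sketch: embed one query using the pooling type taken from the context
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

std::vector<int32_t> query_tokens = llama_tokenize(ctx, "example query", true); // placeholder text

llama_batch query_batch = llama_batch_init(query_tokens.size(), 0, 1);
batch_add_seq(query_batch, query_tokens, 0, pooling_type); // marks logits per pooling type

if (llama_decode(ctx, query_batch) == 0) {
    const int n_embd = llama_n_embd(model);
    const float * emb = llama_get_embeddings_seq(ctx, 0); // pooled embedding for sequence 0
    // ... compare emb[0..n_embd) against the chunk embeddings ...
}

llama_batch_free(query_batch);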
