Skip to content

Commit 201a190

Browse files
committed
wip [no ci]
1 parent 7f87172 commit 201a190

File tree

7 files changed

+142
-130
lines changed

7 files changed

+142
-130
lines changed

include/llama.h

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ extern "C" {
5555
// TODO: show sample usage
5656
//
5757

58+
// struct llama_vocab; // TODO: add in the future
5859
struct llama_model;
5960
struct llama_context;
6061

@@ -423,24 +424,23 @@ extern "C" {
423424
LLAMA_API bool llama_supports_mlock (void);
424425
LLAMA_API bool llama_supports_gpu_offload(void);
425426

426-
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
427-
LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx);
428-
429427
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
430428
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
431429
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
432430
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
433431

434-
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
435-
436-
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
437-
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
438-
439432
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
440433
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
441434
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
442435
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
443436

437+
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
438+
LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx);
439+
440+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
441+
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
442+
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
443+
444444
// Get the model's RoPE frequency scaling factor
445445
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
446446

@@ -967,36 +967,16 @@ extern "C" {
967967
//
968968

969969
// TODO: args become llama_sampling_params
970-
LLAMA_API struct llama_sampling * llama_sampling_init(int32_t n_vocab, const char * grammar_str, const char * grammar_root);
970+
// TODO: llama_model should become llama_vocab
971+
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root);
971972

972973
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
973974

974-
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_grammar * grammar);
975+
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
975976

976977
// Sets the current rng seed.
977978
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
978979

979-
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
980-
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
981-
LLAMA_API void llama_sampling_repetition_penalties(
982-
struct llama_sampling * smpl,
983-
llama_token_data_array * candidates,
984-
const llama_token * last_tokens,
985-
size_t penalty_last_n,
986-
float penalty_repeat,
987-
float penalty_freq,
988-
float penalty_present);
989-
990-
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
991-
/// @param logits Logits extracted from the original generation context.
992-
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
993-
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
994-
LLAMA_API void llama_sampling_apply_guidance(
995-
struct llama_sampling * smpl,
996-
float * logits,
997-
float * logits_guidance,
998-
float scale);
999-
1000980
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1001981
LLAMA_API void llama_sampling_softmax(
1002982
struct llama_sampling * smpl,
@@ -1050,6 +1030,32 @@ extern "C" {
10501030
llama_token_data_array * candidates,
10511031
float temp);
10521032

1033+
/// @details Apply constraints from grammar
1034+
LLAMA_API void llama_sampling_grammar(
1035+
struct llama_sampling * smpl,
1036+
llama_token_data_array * candidates);
1037+
1038+
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
1039+
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
1040+
LLAMA_API void llama_sampling_repetition_penalties(
1041+
struct llama_sampling * smpl,
1042+
llama_token_data_array * candidates,
1043+
const llama_token * last_tokens,
1044+
size_t penalty_last_n,
1045+
float penalty_repeat,
1046+
float penalty_freq,
1047+
float penalty_present);
1048+
1049+
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
1050+
/// @param logits Logits extracted from the original generation context.
1051+
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
1052+
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
1053+
LLAMA_API void llama_sampling_apply_guidance(
1054+
struct llama_sampling * smpl,
1055+
float * logits,
1056+
float * logits_guidance,
1057+
float scale);
1058+
10531059
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
10541060
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
10551061
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1082,21 +1088,14 @@ extern "C" {
10821088
struct llama_sampling * smpl,
10831089
llama_token_data_array * candidates);
10841090

1085-
/// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of smpl.
1091+
/// @details Randomly selects a token from the candidates based on their probabilities
10861092
LLAMA_API llama_token llama_sampling_sample(
10871093
struct llama_sampling * smpl,
10881094
llama_token_data_array * candidates);
10891095

1090-
/// @details Apply constraints from grammar
1091-
LLAMA_API void llama_sampling_grammar(
1092-
const struct llama_sampling * smpl,
1093-
const struct llama_context * ctx,
1094-
llama_token_data_array * candidates);
1095-
10961096
/// @details Accepts the sampled token into the grammar
10971097
LLAMA_API void llama_sampling_accept(
10981098
struct llama_sampling * smpl,
1099-
struct llama_context * ctx,
11001099
llama_token token);
11011100

11021101
//
@@ -1116,8 +1115,8 @@ extern "C" {
11161115
// Performance information
11171116
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
11181117

1119-
LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar);
1120-
LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar);
1118+
LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl);
1119+
LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl);
11211120

11221121
// Print system information
11231122
LLAMA_API const char * llama_print_system_info(void);

src/llama-grammar.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -502,16 +502,16 @@ bool llama_grammar_parser::parse(const char * src) {
502502
return true;
503503
}
504504

505-
void llama_grammar::print(FILE * file) {
505+
void llama_grammar_parser::print(FILE * file) {
506506
try {
507507
std::map<uint32_t, std::string> symbol_id_names;
508-
for (const auto & kv : parser.symbol_ids) {
508+
for (const auto & kv : symbol_ids) {
509509
symbol_id_names[kv.second] = kv.first;
510510
}
511-
for (size_t i = 0, end = parser.rules.size(); i < end; i++) {
511+
for (size_t i = 0, end = rules.size(); i < end; i++) {
512512
// fprintf(file, "%zu: ", i);
513-
// print_rule_binary(file, parser.rules[i]);
514-
print_rule(file, uint32_t(i), parser.rules[i], symbol_id_names);
513+
// print_rule_binary(file, rules[i]);
514+
print_rule(file, uint32_t(i), rules[i], symbol_id_names);
515515
// fprintf(file, "\n");
516516
}
517517
} catch (const std::exception & err) {
@@ -848,7 +848,7 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & gram
848848
return result;
849849
}
850850

851-
void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
851+
void llama_grammar_apply_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
852852
bool allow_eog = false;
853853
for (const auto & stack : grammar.stacks) {
854854
if (stack.empty()) {
@@ -885,7 +885,7 @@ void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struc
885885
}
886886
}
887887

888-
void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
888+
void llama_grammar_accept_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
889889
if (llama_token_is_eog_impl(vocab, token)) {
890890
for (const auto & stack : grammar.stacks) {
891891
if (stack.empty()) {

src/llama-grammar.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ struct llama_grammar_parser {
109109
const char * parse_rule(const char * src);
110110

111111
bool parse(const char * src);
112+
void print(FILE * file);
112113
};
113114

114115
struct llama_grammar {
@@ -118,14 +119,10 @@ struct llama_grammar {
118119
// buffer for partially generated UTF-8 sequence from accepted tokens
119120
llama_partial_utf8 partial_utf8;
120121

121-
llama_grammar_parser parser;
122-
123122
mutable int64_t t_total_us;
124123

125124
mutable int32_t n_sample;
126125
mutable int32_t n_accept;
127-
128-
void print(FILE * file);
129126
};
130127

131128
//
@@ -138,12 +135,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar);
138135

139136
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar);
140137

141-
void llama_grammar_sample_impl(
138+
// TODO: move the API below as member functions of llama_grammar
139+
void llama_grammar_apply_impl(
142140
const struct llama_grammar & grammar,
143141
const struct llama_vocab & vocab,
144142
llama_token_data_array * candidates);
145143

146-
void llama_grammar_accept_token_impl(
144+
void llama_grammar_accept_impl(
147145
struct llama_grammar & grammar,
148146
const struct llama_vocab & vocab,
149147
llama_token token);

src/llama-sampling.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "llama-sampling.h"
22

3+
#include "llama-vocab.h"
34
#include "llama-grammar.h"
45

56
#include <algorithm>
@@ -23,7 +24,7 @@ static void llama_log_softmax(float * array, size_t size) {
2324
}
2425
}
2526

26-
llama_sampling::llama_sampling(int32_t n_vocab, const char * grammar_str, const char * grammar_root) : n_vocab(n_vocab) {
27+
llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : vocab(vocab) {
2728
if (grammar_str != nullptr) {
2829
grammar = llama_grammar_init_impl(grammar_str, grammar_root);
2930
}
@@ -35,8 +36,8 @@ llama_sampling::~llama_sampling() {
3536
}
3637
}
3738

38-
struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab, const char * grammar_str, const char * grammar_root) {
39-
return new llama_sampling(n_vocab, grammar_str, grammar_root);
39+
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
40+
return new llama_sampling(vocab, grammar_str, grammar_root);
4041
}
4142

4243
void llama_sampling_free_impl(struct llama_sampling * sampling) {
@@ -411,6 +412,12 @@ void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data
411412
}
412413
}
413414

415+
void llama_sampling_grammar_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
416+
if (smpl.grammar) {
417+
llama_grammar_apply_impl(*smpl.grammar, smpl.vocab, candidates);
418+
}
419+
}
420+
414421
void llama_sampling_repetition_penalties_impl(
415422
struct llama_sampling & /*smpl*/,
416423
llama_token_data_array * candidates,
@@ -457,12 +464,12 @@ void llama_sampling_apply_guidance_impl(
457464
float * logits,
458465
float * logits_guidance,
459466
float scale) {
460-
const auto n_vocab = smpl.n_vocab;
467+
const auto n_vocab = smpl.vocab.n_vocab;
461468

462469
llama_log_softmax(logits, n_vocab);
463470
llama_log_softmax(logits_guidance, n_vocab);
464471

465-
for (int i = 0; i < n_vocab; ++i) {
472+
for (uint32_t i = 0; i < n_vocab; ++i) {
466473
auto & l = logits[i];
467474
const auto & g = logits_guidance[i];
468475

@@ -471,7 +478,7 @@ void llama_sampling_apply_guidance_impl(
471478
}
472479

473480
llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
474-
const int32_t n_vocab = float(smpl.n_vocab);
481+
const int32_t n_vocab = float(smpl.vocab.n_vocab);
475482

476483
llama_sampling_softmax_impl(smpl, candidates);
477484

@@ -570,3 +577,11 @@ llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, ll
570577
llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
571578
return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng);
572579
}
580+
581+
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token) {
582+
// TODO: implement token storing in history
583+
584+
if (smpl.grammar) {
585+
llama_grammar_accept_impl(*smpl.grammar, smpl.vocab, token);
586+
}
587+
}

src/llama-sampling.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
#include "llama-impl.h"
44
#include "llama-grammar.h"
55

6+
struct llama_vocab;
67
struct llama_grammar;
78

89
struct llama_sampling {
9-
llama_sampling(int32_t n_vocab, const char * grammar_str, const char * grammar_root);
10+
llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
1011
~llama_sampling();
1112

12-
const int32_t n_vocab;
13+
const struct llama_vocab & vocab;
1314

1415
std::mt19937 rng;
1516

@@ -24,10 +25,11 @@ struct llama_sampling {
2425
// internal API
2526
//
2627

27-
struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab, const char * grammar_str, const char * grammar_root);
28+
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
2829

2930
void llama_sampling_free_impl(struct llama_sampling * sampling);
3031

32+
// TODO: move the API below as member functions of llama_sampling
3133
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
3234

3335
void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
@@ -38,6 +40,7 @@ void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_dat
3840
void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
3941
void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
4042
void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);
43+
void llama_sampling_grammar_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
4144

4245
void llama_sampling_repetition_penalties_impl(
4346
struct llama_sampling & smpl,
@@ -60,3 +63,4 @@ llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl,
6063
llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng);
6164
llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
6265

66+
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token);

src/llama-vocab.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct llama_vocab {
1818
tattr attr;
1919
};
2020

21+
uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
22+
2123
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
2224
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
2325

@@ -61,8 +63,6 @@ struct llama_vocab {
6163
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
6264
};
6365

64-
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
65-
6666
//
6767
// internal API
6868
//
@@ -75,6 +75,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
7575
bool add_special,
7676
bool parse_special = false);
7777

78+
// TODO: move the API below as member functions of llama_vocab
7879
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
7980

8081
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);

0 commit comments

Comments
 (0)