Skip to content

Commit 80596fc

Browse files
committed
Add simple sampler function with grammar
1 parent ba11eb9 commit 80596fc

File tree

3 files changed

+136
-22
lines changed

3 files changed

+136
-22
lines changed

common/common.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,9 @@ llama_token llama_sample_token(
871871
const std::vector<llama_token> & last_tokens,
872872
std::vector<llama_token_data> & candidates,
873873
int idx) {
874+
875+
LOG("idx sample_token: %d\n", idx);
876+
874877
const int n_ctx = llama_n_ctx(ctx);
875878
const int n_vocab = llama_n_vocab(ctx);
876879

examples/grammar/grammar.cpp

Lines changed: 106 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,114 @@
11
#include "grammar.h"
2-
#include <unordered_map>
3-
#include <string>
4-
#include <vector>
2+
#include <stdlib.h>
53

64
struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str) {
7-
static std::unordered_map<std::string, grammar_parser::parse_state> parsed_grammar_cache;
8-
std::string key = grammar_str;
9-
10-
auto it = parsed_grammar_cache.find(key);
11-
grammar_parser::parse_state parsed_grammar;
12-
if (it != parsed_grammar_cache.end()) {
13-
// Use cached parsed grammar
14-
parsed_grammar = it->second;
15-
} else {
16-
// Parse and cache the result
17-
parsed_grammar = grammar_parser::parse(grammar_str);
18-
parsed_grammar_cache[key] = parsed_grammar;
19-
20-
// Optionally print the grammar
21-
grammar_parser::print_grammar(stderr, parsed_grammar);
5+
static std::unordered_map<std::string, grammar_parser::parse_state> parsed_grammar_cache;
6+
std::string key = grammar_str;
7+
8+
auto it = parsed_grammar_cache.find(key);
9+
grammar_parser::parse_state parsed_grammar;
10+
if (it != parsed_grammar_cache.end()) {
11+
// Use cached parsed grammar
12+
parsed_grammar = it->second;
13+
} else {
14+
// Parse and cache the result
15+
parsed_grammar = grammar_parser::parse(grammar_str);
16+
parsed_grammar_cache[key] = parsed_grammar;
17+
18+
// Optionally print the grammar
19+
grammar_parser::print_grammar(stderr, parsed_grammar);
20+
}
21+
22+
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
23+
24+
struct llama_grammar * grammar = NULL;
25+
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
26+
27+
return grammar;
28+
}
29+
30+
struct llama_sampler_params llama_sampler_default_params() {
31+
struct llama_sampler_params result = {
32+
0.80f, // temp;
33+
1.10f, // repeat_penalty
34+
64, // last_n_repeat
35+
0.00f, // frequency_penalty
36+
0.00f, // presence_penalty
37+
2, // mirostat
38+
5.00f, // mirostat_tau
39+
0.10f, // mirostat_eta
40+
};
41+
return result;
42+
}
43+
44+
// Sample one token from the candidate distribution in cur_p, applying
// repetition/frequency/presence penalties, optional grammar constraints, and
// the temperature / mirostat strategy selected in params.
//
// A function-local static history of the last n_ctx sampled tokens feeds the
// penalty samplers; pass reset=true at the start of a new generation to clear
// it. Returns the sampled token id, after advancing the grammar (if any).
//
// NOTE(review): the static history is sized from the first caller's n_ctx and
// shared across calls, and mirostat_mu below is also static — this function is
// single-session and not thread-safe; confirm callers.
llama_token llama_grammar_sample_token(struct llama_context * ctx,
                                       struct llama_grammar * grammar,
                                       struct llama_sampler_params params,
                                       struct llama_token_data_array * cur_p,
                                       bool reset) {

    const int n_ctx = llama_n_ctx(ctx);

    // Sampled-token history used by the penalty samplers.
    // BUG FIX: this was zero-filled on *every* call, erasing the history
    // before the penalties could see it; clear only on reset now.
    static std::vector<llama_token> last_tokens(n_ctx, 0);

    if (reset) {
        last_tokens.assign(n_ctx, 0);
    }

    const float   temp            = params.temp;
    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
    const float   repeat_penalty  = params.repeat_penalty;
    const float   alpha_presence  = params.presence_penalty;
    const float   alpha_frequency = params.frequency_penalty;
    const int     mirostat        = params.mirostat;
    const float   mirostat_tau    = params.mirostat_tau;
    const float   mirostat_eta    = params.mirostat_eta;

    llama_token id = 0;

    // Apply repetition, frequency and presence penalties over the most recent
    // window of the history.
    if (!last_tokens.empty()) {
        const int last_n_repeat = std::min(std::min((int) last_tokens.size(), (int) repeat_last_n), n_ctx);

        llama_sample_repetition_penalty(ctx, cur_p,
                last_tokens.data() + last_tokens.size() - last_n_repeat,
                last_n_repeat, repeat_penalty);
        llama_sample_frequency_and_presence_penalties(ctx, cur_p,
                last_tokens.data() + last_tokens.size() - last_n_repeat,
                last_n_repeat, alpha_frequency, alpha_presence);
    }

    // Restrict candidates to tokens the grammar can accept next.
    if (grammar != NULL) {
        llama_sample_grammar(ctx, cur_p, grammar);
    }

    if (temp <= 0) {
        // Greedy sampling
        id = llama_sample_token_greedy(ctx, cur_p);
    } else if (mirostat == 1) {
        static float mirostat_mu = 2.0f * mirostat_tau;  // NOTE(review): persists across resets
        const int mirostat_m = 100;
        llama_sample_temperature(ctx, cur_p, temp);
        id = llama_sample_token_mirostat(ctx, cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
    } else if (mirostat == 2) {
        static float mirostat_mu = 2.0f * mirostat_tau;  // NOTE(review): persists across resets
        llama_sample_temperature(ctx, cur_p, temp);
        id = llama_sample_token_mirostat_v2(ctx, cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
    } else {
        // BUG FIX: mirostat == 0 with temp > 0 previously performed no
        // sampling at all and silently returned token id 0; sample from the
        // temperature-scaled distribution instead.
        llama_sample_temperature(ctx, cur_p, temp);
        id = llama_sample_token(ctx, cur_p);
    }

    // Advance the grammar's parse state with the chosen token.
    if (grammar != NULL) {
        llama_grammar_accept_token(ctx, grammar, id);
    }

    // Slide the history window: drop the oldest entry, record the new token.
    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);

    return id;
}

examples/grammar/grammar.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,42 @@
11
#ifndef GRAMMAR_H
22
#define GRAMMAR_H
33

4+
#include <string>
5+
#include <vector>
6+
#include <unordered_map>
7+
#include <stddef.h>
8+
#include <stdint.h>
9+
#include <stdbool.h>
10+
411
#include "llama.h"
512
#include "grammar-parser.h"
613

714
#ifdef __cplusplus
815
extern "C" {
916
#endif
17+
struct llama_sampler_params {
18+
float temp;
19+
float repeat_penalty;
20+
int32_t repeat_last_n;
21+
float frequency_penalty;
22+
float presence_penalty;
23+
int32_t mirostat;
24+
float mirostat_tau;
25+
float mirostat_eta;
26+
};
27+
28+
llama_sampler_params llama_sampler_default_params();
1029

1130
struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str);
1231

32+
llama_token llama_grammar_sample_token(llama_context * ctx,
33+
llama_grammar * grammar,
34+
llama_sampler_params params,
35+
llama_token_data_array * cur_p,
36+
bool reset);
37+
38+
39+
1340
#ifdef __cplusplus
1441
}
1542
#endif

0 commit comments

Comments
 (0)