Skip to content

Commit 675f305

Browse files
committed
llama : move grammar code into llama-grammar
ggml-ci
1 parent 0ddc8e3 commit 675f305

File tree

12 files changed

+742
-672
lines changed

12 files changed

+742
-672
lines changed

Makefile

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,8 @@ OBJ_GGML += \
876876

877877
OBJ_LLAMA = \
878878
src/llama.o \
879+
src/llama-vocab.o \
880+
src/llama-grammar.o \
879881
src/llama-sampling.o \
880882
src/unicode.o \
881883
src/unicode-data.o
@@ -1066,6 +1068,20 @@ src/llama.o: \
10661068
ggml/include/ggml-backend.h
10671069
$(CXX) $(CXXFLAGS) -c $< -o $@
10681070

1071+
src/llama-vocab.o: \
1072+
src/llama-vocab.cpp \
1073+
src/llama-vocab.h \
1074+
src/llama-impl.h \
1075+
include/llama.h
1076+
$(CXX) $(CXXFLAGS) -c $< -o $@
1077+
1078+
src/llama-grammar.o: \
1079+
src/llama-grammar.cpp \
1080+
src/llama-grammar.h \
1081+
src/llama-impl.h \
1082+
include/llama.h
1083+
$(CXX) $(CXXFLAGS) -c $< -o $@
1084+
10691085
src/llama-sampling.o: \
10701086
src/llama-sampling.cpp \
10711087
src/llama-sampling.h \
@@ -1448,7 +1464,7 @@ run-benchmark-matmult: llama-benchmark-matmult
14481464
.PHONY: run-benchmark-matmult swift
14491465

14501466
tests/test-llama-grammar: tests/test-llama-grammar.cpp \
1451-
$(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
1467+
$(OBJ_ALL)
14521468
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
14531469
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
14541470

common/sampling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
330330
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
331331

332332
// Apply grammar constraints to the single token
333-
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
333+
llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar);
334334

335335
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
336336
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
421421

422422
// apply grammar checks before sampling logic
423423
if (apply_grammar && ctx_sampling->grammar != NULL) {
424-
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
424+
llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar);
425425
}
426426

427427
return cur_p;

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,23 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
1616
auto decoded = decode_utf8(input_str, {});
1717
const auto & code_points = decoded.first;
1818

19+
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
20+
1921
size_t pos = 0;
2022
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
21-
auto prev_stacks = grammar->stacks;
22-
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
23-
if (grammar->stacks.empty()) {
23+
const llama_grammar_rules & prev_rules = llama_grammar_get_rules (grammar);
24+
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
25+
llama_grammar_accept(prev_rules, prev_stacks, *it, cur_stacks);
26+
if (cur_stacks.empty()) {
2427
error_pos = pos;
2528
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
26-
grammar->stacks = prev_stacks;
29+
cur_stacks = prev_stacks;
2730
return false;
2831
}
2932
++pos;
3033
}
3134

32-
for (const auto & stack : grammar->stacks) {
35+
for (const auto & stack : cur_stacks) {
3336
if (stack.empty()) {
3437
return true;
3538
}

include/llama.h

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,18 @@ extern "C" {
10031003

10041004
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
10051005

1006+
/// @details Apply constraints from grammar
1007+
LLAMA_API void llama_grammar_sample(
1008+
struct llama_context * ctx,
1009+
llama_token_data_array * candidates,
1010+
const struct llama_grammar * grammar);
1011+
1012+
/// @details Accepts the sampled token into the grammar
1013+
LLAMA_API void llama_grammar_accept_token(
1014+
struct llama_context * ctx,
1015+
struct llama_grammar * grammar,
1016+
llama_token token);
1017+
10061018
//
10071019
// Sampling functions
10081020
//
@@ -1121,18 +1133,6 @@ extern "C" {
11211133
struct llama_context * ctx,
11221134
llama_token_data_array * candidates);
11231135

1124-
/// @details Apply constraints from grammar
1125-
LLAMA_API void llama_sample_grammar(
1126-
struct llama_context * ctx,
1127-
llama_token_data_array * candidates,
1128-
const struct llama_grammar * grammar);
1129-
1130-
/// @details Accepts the sampled token into the grammar
1131-
LLAMA_API void llama_grammar_accept_token(
1132-
struct llama_context * ctx,
1133-
struct llama_grammar * grammar,
1134-
llama_token token);
1135-
11361136
//
11371137
// Model split
11381138
//
@@ -1175,38 +1175,41 @@ extern "C" {
11751175

11761176
struct ggml_tensor;
11771177

1178+
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1179+
struct llama_context * ctx
1180+
);
1181+
11781182
struct llama_partial_utf8 {
11791183
uint32_t value; // bit value so far (unshifted)
11801184
int n_remain; // num bytes remaining; -1 indicates invalid sequence
11811185
};
11821186

1183-
struct llama_grammar {
1184-
const std::vector<std::vector<llama_grammar_element>> rules;
1185-
std::vector<std::vector<const llama_grammar_element *>> stacks;
1186-
1187-
// buffer for partially generated UTF-8 sequence from accepted tokens
1188-
llama_partial_utf8 partial_utf8;
1189-
};
1190-
11911187
struct llama_grammar_candidate {
11921188
size_t index;
11931189
const uint32_t * code_points;
11941190
llama_partial_utf8 partial_utf8;
11951191
};
11961192

1197-
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1198-
struct llama_context * ctx
1199-
);
1193+
using llama_grammar_rules = std::vector<std::vector<llama_grammar_element>>;
1194+
using llama_grammar_stacks = std::vector<std::vector<const llama_grammar_element *>>;
1195+
1196+
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
1197+
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
12001198

12011199
void llama_grammar_accept(
1202-
const std::vector<std::vector<llama_grammar_element>> & rules,
1203-
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
1204-
const uint32_t chr,
1205-
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
1200+
const llama_grammar_rules & rules,
1201+
const llama_grammar_stacks & stacks,
1202+
const uint32_t chr,
1203+
llama_grammar_stacks & new_stacks);
1204+
1205+
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
1206+
const std::vector<std::vector<llama_grammar_element>> & rules,
1207+
const std::vector<const llama_grammar_element *> & stack,
1208+
const std::vector<llama_grammar_candidate> & candidates);
12061209

12071210
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
12081211
const std::string & src,
1209-
llama_partial_utf8 partial_start);
1212+
llama_partial_utf8 partial_start);
12101213

12111214
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
12121215
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.

src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ endif()
1414
add_library(llama
1515
../include/llama.h
1616
llama.cpp
17+
llama-vocab.cpp
18+
llama-grammar.cpp
1719
llama-sampling.cpp
1820
unicode.h
1921
unicode.cpp

0 commit comments

Comments
 (0)