Skip to content

Commit ee77efe

Browse files
authored
test : add simple grammar parsing tests (#2594)
* adds simple grammar parsing tests * adds cassert header
1 parent f64d44a commit ee77efe

File tree

4 files changed

+255
-1
lines changed

4 files changed

+255
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ poetry.lock
7070
poetry.toml
7171

7272
# Test binaries
73+
tests/test-grammar-parser
7374
tests/test-double-float
7475
tests/test-grad0
7576
tests/test-opt

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test
33

44
# Binaries only useful for tests
5-
TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
5+
TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
66

77
default: $(BUILD_TARGETS)
88

@@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
412412
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
413413
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
414414

415+
tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
416+
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
417+
415418
tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS)
416419
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
417420

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ llama_add_test(test-quantize-fns.cpp)
1111
llama_add_test(test-quantize-perf.cpp)
1212
llama_add_test(test-sampling.cpp)
1313
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
14+
llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
1415
llama_add_test(test-grad0.cpp) # SLOW
1516
# llama_add_test(test-opt.cpp) # SLOW

tests/test-grammar-parser.cpp

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#ifdef NDEBUG
2+
#undef NDEBUG
3+
#endif
4+
5+
#include "llama.h"
6+
#include "examples/grammar-parser.cpp"
7+
#include <cassert>
8+
9+
int main()
10+
{
11+
grammar_parser::parse_state parsed_grammar;
12+
13+
const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+
14+
expr ::= term ([-+*/] term)*
15+
term ::= [0-9]+)""";
16+
17+
parsed_grammar = grammar_parser::parse(grammar_bytes);
18+
19+
std::vector<std::pair<std::string, uint32_t>> expected = {
20+
{"expr", 2},
21+
{"expr_5", 5},
22+
{"expr_6", 6},
23+
{"root", 0},
24+
{"root_1", 1},
25+
{"root_4", 4},
26+
{"term", 3},
27+
{"term_7", 7},
28+
};
29+
30+
uint32_t index = 0;
31+
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
32+
{
33+
std::string key = it->first;
34+
uint32_t value = it->second;
35+
std::pair<std::string, uint32_t> expected_pair = expected[index];
36+
37+
// pretty print error message before asserting
38+
if (expected_pair.first != key || expected_pair.second != value)
39+
{
40+
fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
41+
fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
42+
fprintf(stderr, "expected_pair != actual_pair\n");
43+
}
44+
45+
assert(expected_pair.first == key && expected_pair.second == value);
46+
47+
index++;
48+
}
49+
std::vector<llama_grammar_element> expected_rules = {
50+
{LLAMA_GRETYPE_RULE_REF, 4},
51+
{LLAMA_GRETYPE_END, 0},
52+
{LLAMA_GRETYPE_RULE_REF, 2},
53+
{LLAMA_GRETYPE_CHAR, 61},
54+
{LLAMA_GRETYPE_RULE_REF, 3},
55+
{LLAMA_GRETYPE_CHAR, 10},
56+
{LLAMA_GRETYPE_END, 0},
57+
{LLAMA_GRETYPE_RULE_REF, 3},
58+
{LLAMA_GRETYPE_RULE_REF, 6},
59+
{LLAMA_GRETYPE_END, 0},
60+
{LLAMA_GRETYPE_RULE_REF, 7},
61+
{LLAMA_GRETYPE_END, 0},
62+
{LLAMA_GRETYPE_RULE_REF, 1},
63+
{LLAMA_GRETYPE_RULE_REF, 4},
64+
{LLAMA_GRETYPE_ALT, 0},
65+
{LLAMA_GRETYPE_RULE_REF, 1},
66+
{LLAMA_GRETYPE_END, 0},
67+
{LLAMA_GRETYPE_CHAR, 45},
68+
{LLAMA_GRETYPE_CHAR_ALT, 43},
69+
{LLAMA_GRETYPE_CHAR_ALT, 42},
70+
{LLAMA_GRETYPE_CHAR_ALT, 47},
71+
{LLAMA_GRETYPE_RULE_REF, 3},
72+
{LLAMA_GRETYPE_END, 0},
73+
{LLAMA_GRETYPE_RULE_REF, 5},
74+
{LLAMA_GRETYPE_RULE_REF, 6},
75+
{LLAMA_GRETYPE_ALT, 0},
76+
{LLAMA_GRETYPE_END, 0},
77+
{LLAMA_GRETYPE_CHAR, 48},
78+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
79+
{LLAMA_GRETYPE_RULE_REF, 7},
80+
{LLAMA_GRETYPE_ALT, 0},
81+
{LLAMA_GRETYPE_CHAR, 48},
82+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
83+
{LLAMA_GRETYPE_END, 0},
84+
};
85+
86+
index = 0;
87+
for (auto rule : parsed_grammar.rules)
88+
{
89+
// compare rule to expected rule
90+
for (uint32_t i = 0; i < rule.size(); i++)
91+
{
92+
llama_grammar_element element = rule[i];
93+
llama_grammar_element expected_element = expected_rules[index];
94+
95+
// pretty print error message before asserting
96+
if (expected_element.type != element.type || expected_element.value != element.value)
97+
{
98+
fprintf(stderr, "index: %d\n", index);
99+
fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
100+
fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
101+
fprintf(stderr, "expected_element != actual_element\n");
102+
}
103+
104+
assert(expected_element.type == element.type && expected_element.value == element.value);
105+
index++;
106+
}
107+
}
108+
109+
const char *longer_grammar_bytes = R"""(
110+
root ::= (expr "=" ws term "\n")+
111+
expr ::= term ([-+*/] term)*
112+
term ::= ident | num | "(" ws expr ")" ws
113+
ident ::= [a-z] [a-z0-9_]* ws
114+
num ::= [0-9]+ ws
115+
ws ::= [ \t\n]*
116+
)""";
117+
118+
parsed_grammar = grammar_parser::parse(longer_grammar_bytes);
119+
120+
expected = {
121+
{"expr", 2},
122+
{"expr_6", 6},
123+
{"expr_7", 7},
124+
{"ident", 8},
125+
{"ident_10", 10},
126+
{"num", 9},
127+
{"num_11", 11},
128+
{"root", 0},
129+
{"root_1", 1},
130+
{"root_5", 5},
131+
{"term", 4},
132+
{"ws", 3},
133+
{"ws_12", 12},
134+
};
135+
136+
index = 0;
137+
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
138+
{
139+
std::string key = it->first;
140+
uint32_t value = it->second;
141+
std::pair<std::string, uint32_t> expected_pair = expected[index];
142+
143+
// pretty print error message before asserting
144+
if (expected_pair.first != key || expected_pair.second != value)
145+
{
146+
fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
147+
fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
148+
fprintf(stderr, "expected_pair != actual_pair\n");
149+
}
150+
151+
assert(expected_pair.first == key && expected_pair.second == value);
152+
153+
index++;
154+
}
155+
expected_rules = {
156+
{LLAMA_GRETYPE_RULE_REF, 5},
157+
{LLAMA_GRETYPE_END, 0},
158+
{LLAMA_GRETYPE_RULE_REF, 2},
159+
{LLAMA_GRETYPE_CHAR, 61},
160+
{LLAMA_GRETYPE_RULE_REF, 3},
161+
{LLAMA_GRETYPE_RULE_REF, 4},
162+
{LLAMA_GRETYPE_CHAR, 10},
163+
{LLAMA_GRETYPE_END, 0},
164+
{LLAMA_GRETYPE_RULE_REF, 4},
165+
{LLAMA_GRETYPE_RULE_REF, 7},
166+
{LLAMA_GRETYPE_END, 0},
167+
{LLAMA_GRETYPE_RULE_REF, 12},
168+
{LLAMA_GRETYPE_END, 0},
169+
{LLAMA_GRETYPE_RULE_REF, 8},
170+
{LLAMA_GRETYPE_ALT, 0},
171+
{LLAMA_GRETYPE_RULE_REF, 9},
172+
{LLAMA_GRETYPE_ALT, 0},
173+
{LLAMA_GRETYPE_CHAR, 40},
174+
{LLAMA_GRETYPE_RULE_REF, 3},
175+
{LLAMA_GRETYPE_RULE_REF, 2},
176+
{LLAMA_GRETYPE_CHAR, 41},
177+
{LLAMA_GRETYPE_RULE_REF, 3},
178+
{LLAMA_GRETYPE_END, 0},
179+
{LLAMA_GRETYPE_RULE_REF, 1},
180+
{LLAMA_GRETYPE_RULE_REF, 5},
181+
{LLAMA_GRETYPE_ALT, 0},
182+
{LLAMA_GRETYPE_RULE_REF, 1},
183+
{LLAMA_GRETYPE_END, 0},
184+
{LLAMA_GRETYPE_CHAR, 45},
185+
{LLAMA_GRETYPE_CHAR_ALT, 43},
186+
{LLAMA_GRETYPE_CHAR_ALT, 42},
187+
{LLAMA_GRETYPE_CHAR_ALT, 47},
188+
{LLAMA_GRETYPE_RULE_REF, 4},
189+
{LLAMA_GRETYPE_END, 0},
190+
{LLAMA_GRETYPE_RULE_REF, 6},
191+
{LLAMA_GRETYPE_RULE_REF, 7},
192+
{LLAMA_GRETYPE_ALT, 0},
193+
{LLAMA_GRETYPE_END, 0},
194+
{LLAMA_GRETYPE_CHAR, 97},
195+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
196+
{LLAMA_GRETYPE_RULE_REF, 10},
197+
{LLAMA_GRETYPE_RULE_REF, 3},
198+
{LLAMA_GRETYPE_END, 0},
199+
{LLAMA_GRETYPE_RULE_REF, 11},
200+
{LLAMA_GRETYPE_RULE_REF, 3},
201+
{LLAMA_GRETYPE_END, 0},
202+
{LLAMA_GRETYPE_CHAR, 97},
203+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
204+
{LLAMA_GRETYPE_CHAR_ALT, 48},
205+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
206+
{LLAMA_GRETYPE_CHAR_ALT, 95},
207+
{LLAMA_GRETYPE_RULE_REF, 10},
208+
{LLAMA_GRETYPE_ALT, 0},
209+
{LLAMA_GRETYPE_END, 0},
210+
{LLAMA_GRETYPE_CHAR, 48},
211+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
212+
{LLAMA_GRETYPE_RULE_REF, 11},
213+
{LLAMA_GRETYPE_ALT, 0},
214+
{LLAMA_GRETYPE_CHAR, 48},
215+
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
216+
{LLAMA_GRETYPE_END, 0},
217+
{LLAMA_GRETYPE_CHAR, 32},
218+
{LLAMA_GRETYPE_CHAR_ALT, 9},
219+
{LLAMA_GRETYPE_CHAR_ALT, 10},
220+
{LLAMA_GRETYPE_RULE_REF, 12},
221+
{LLAMA_GRETYPE_ALT, 0},
222+
{LLAMA_GRETYPE_END, 0},
223+
};
224+
225+
index = 0;
226+
for (auto rule : parsed_grammar.rules)
227+
{
228+
// compare rule to expected rule
229+
for (uint32_t i = 0; i < rule.size(); i++)
230+
{
231+
llama_grammar_element element = rule[i];
232+
llama_grammar_element expected_element = expected_rules[index];
233+
234+
// pretty print error message before asserting
235+
if (expected_element.type != element.type || expected_element.value != element.value)
236+
{
237+
fprintf(stderr, "index: %d\n", index);
238+
fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
239+
fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
240+
fprintf(stderr, "expected_element != actual_element\n");
241+
}
242+
243+
assert(expected_element.type == element.type && expected_element.value == element.value);
244+
index++;
245+
}
246+
}
247+
248+
return 0;
249+
}

0 commit comments

Comments
 (0)