16
16
from vllm .sampling_params import GuidedDecodingParams , SamplingParams
17
17
18
18
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
19
- GUIDED_DECODING_BACKENDS = [
19
+
20
+ # Separate the backends that support grammars from those
21
+ # that only support regex-based constraints in tests.
22
+ GRAMMAR_DECODING_BACKENDS = [
20
23
# (backend, disable_any_whitespace),
21
- ("outlines" , False ),
22
24
("lm-format-enforcer" , False ),
23
25
("xgrammar" , True ),
24
26
("guidance" , True ),
25
27
]
26
28
29
+ ALL_DECODING_BACKENDS = ([("outlines" , False )] + GRAMMAR_DECODING_BACKENDS )
30
+
27
31
28
32
@pytest .fixture (scope = "module" )
29
33
def llm ():
@@ -39,7 +43,7 @@ def llm():
39
43
40
44
@pytest .mark .skip_global_cleanup
41
45
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
42
- GUIDED_DECODING_BACKENDS )
46
+ ALL_DECODING_BACKENDS )
43
47
def test_guided_regex (sample_regex , llm , guided_decoding_backend : str ,
44
48
disable_any_whitespace : bool ):
45
49
sampling_params = SamplingParams (
@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
49
53
regex = sample_regex ,
50
54
backend = guided_decoding_backend ,
51
55
disable_any_whitespace = disable_any_whitespace ))
56
+
52
57
outputs = llm .generate (prompts = [
53
58
f"Give an example IPv4 address with this regex: { sample_regex } "
54
59
] * 2 ,
@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
69
74
70
75
@pytest .mark .skip_global_cleanup
71
76
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
72
- GUIDED_DECODING_BACKENDS )
77
+ ALL_DECODING_BACKENDS )
73
78
def test_guided_json_completion (sample_json_schema , llm ,
74
79
guided_decoding_backend : str ,
75
80
disable_any_whitespace : bool ):
@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
103
108
104
109
@pytest .mark .skip_global_cleanup
105
110
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
106
- GUIDED_DECODING_BACKENDS )
111
+ ALL_DECODING_BACKENDS )
107
112
def test_guided_complex_json_completion (sample_complex_json_schema , llm ,
108
113
guided_decoding_backend : str ,
109
114
disable_any_whitespace : bool ):
@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
138
143
139
144
@pytest .mark .skip_global_cleanup
140
145
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
141
- GUIDED_DECODING_BACKENDS )
146
+ ALL_DECODING_BACKENDS )
142
147
def test_guided_definition_json_completion (sample_definition_json_schema , llm ,
143
148
guided_decoding_backend : str ,
144
149
disable_any_whitespace : bool ):
@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
173
178
174
179
@pytest .mark .skip_global_cleanup
175
180
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
176
- GUIDED_DECODING_BACKENDS )
181
+ ALL_DECODING_BACKENDS )
177
182
def test_guided_enum_json_completion (sample_enum_json_schema , llm ,
178
183
guided_decoding_backend : str ,
179
184
disable_any_whitespace : bool ):
@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
218
223
219
224
@pytest .mark .skip_global_cleanup
220
225
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
221
- GUIDED_DECODING_BACKENDS )
226
+ ALL_DECODING_BACKENDS )
222
227
def test_guided_choice_completion (sample_guided_choice , llm ,
223
228
guided_decoding_backend : str ,
224
229
disable_any_whitespace : bool ):
@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
248
253
249
254
@pytest .mark .skip_global_cleanup
250
255
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
251
- GUIDED_DECODING_BACKENDS )
256
+ GRAMMAR_DECODING_BACKENDS )
252
257
def test_guided_grammar (sample_sql_statements , llm ,
253
258
guided_decoding_backend : str ,
254
259
disable_any_whitespace : bool ):
@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
344
349
345
350
@pytest .mark .skip_global_cleanup
346
351
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
347
- GUIDED_DECODING_BACKENDS )
352
+ GRAMMAR_DECODING_BACKENDS )
348
353
def test_guided_json_object (llm , guided_decoding_backend : str ,
349
354
disable_any_whitespace : bool ):
350
355
sampling_params = SamplingParams (
@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
377
382
378
383
# Parse to verify it is valid JSON
379
384
parsed_json = json .loads (generated_text )
380
- assert isinstance (parsed_json , dict )
385
+ # A list is not what was intended, but it is still valid
386
+ # JSON.
387
+ assert isinstance (parsed_json , (dict , list ))
381
388
382
389
383
390
class CarType (str , Enum ):
@@ -395,7 +402,7 @@ class CarDescription(BaseModel):
395
402
396
403
@pytest .mark .skip_global_cleanup
397
404
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
398
- GUIDED_DECODING_BACKENDS )
405
+ ALL_DECODING_BACKENDS )
399
406
def test_guided_json_completion_with_enum (llm , guided_decoding_backend : str ,
400
407
disable_any_whitespace : bool ):
401
408
json_schema = CarDescription .model_json_schema ()
@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
427
434
428
435
@pytest .mark .skip_global_cleanup
429
436
@pytest .mark .parametrize ("guided_decoding_backend,disable_any_whitespace" ,
430
- GUIDED_DECODING_BACKENDS )
437
+ ALL_DECODING_BACKENDS )
431
438
def test_guided_number_range_json_completion (llm , guided_decoding_backend : str ,
432
439
disable_any_whitespace : bool ):
433
440
sample_output_schema = {
0 commit comments