
Commit d6902ce

[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)
Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
Parent: 5e53c89
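The tests below exercise the new integration through vLLM's public API. As a quick orientation, here is a minimal sketch of requesting guided decoding with the outlines backend, assembled from the model and parameter names that appear in the diffs (the `guided_decoding=` wiring inside `SamplingParams` is an assumption, not shown on this page):

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Model name and GuidedDecodingParams fields are taken from the tests below;
# the guided_decoding= keyword is assumed from vLLM's sampling API.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
sampling_params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(
        regex=r"\d+\.\d+\.\d+\.\d+",  # illustrative IPv4-shaped constraint
        backend="outlines",
        disable_any_whitespace=False))
outputs = llm.generate(
    prompts=["Give an example IPv4 address:"],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```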

File tree

13 files changed: +807 −464 lines

requirements/common.txt
Lines changed: 3 additions & 1 deletion

```diff
@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
-outlines == 0.1.11
+outlines_core == 0.2.10
+# required for outlines backend disk cache
+diskcache == 5.6.3
 lark == 1.2.2
 xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
 typing_extensions >= 4.10
```
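The new `diskcache` pin backs the comment above: the outlines backend persists compiled guides on disk so they can be reused across runs. A minimal sketch of the library this enables (the cache path and key are illustrative; vLLM's actual cache layout is not part of this diff):

```python
import diskcache

# Illustrative only: persist an expensive compilation result to disk.
cache = diskcache.Cache("/tmp/vllm_outlines_cache")  # hypothetical location
key = ("regex-guide", r"\d+\.\d+\.\d+\.\d+", "Qwen/Qwen2.5-1.5B-Instruct")
if key not in cache:
    cache[key] = "<compiled guide would be stored here>"
print(cache[key])
```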

tests/entrypoints/llm/test_guided_generate.py
Lines changed: 20 additions & 13 deletions

```diff
@@ -16,14 +16,18 @@
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = [
+
+# Separate backends which support grammars vs ones
+# which only support regex based constraints in tests.
+GRAMMAR_DECODING_BACKENDS = [
     # (backend, disable_any_whitespace),
-    ("outlines", False),
     ("lm-format-enforcer", False),
     ("xgrammar", True),
     ("guidance", True),
 ]

+ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
+

 @pytest.fixture(scope="module")
 def llm():
```
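The remaining hunks in this file only swap which list feeds each test. For reference, a self-contained sketch of how pytest expands these `(backend, disable_any_whitespace)` tuples into test cases (the test body is a placeholder):

```python
import pytest

# Each tuple becomes one test case, unpacked into the two named parameters.
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         [("outlines", False), ("xgrammar", True)])
def test_example(guided_decoding_backend: str, disable_any_whitespace: bool):
    assert isinstance(disable_any_whitespace, bool)
```

The rest of the file's hunks follow.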
```diff
@@ -39,7 +43,7 @@ def llm():

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
                       disable_any_whitespace: bool):
     sampling_params = SamplingParams(
@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
             regex=sample_regex,
             backend=guided_decoding_backend,
             disable_any_whitespace=disable_any_whitespace))
+
     outputs = llm.generate(prompts=[
         f"Give an example IPv4 address with this regex: {sample_regex}"
     ] * 2,
@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_json_completion(sample_json_schema, llm,
                                 guided_decoding_backend: str,
                                 disable_any_whitespace: bool):
@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_complex_json_completion(sample_complex_json_schema, llm,
                                         guided_decoding_backend: str,
                                         disable_any_whitespace: bool):
@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_definition_json_completion(sample_definition_json_schema, llm,
                                            guided_decoding_backend: str,
                                            disable_any_whitespace: bool):
@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_enum_json_completion(sample_enum_json_schema, llm,
                                      guided_decoding_backend: str,
                                      disable_any_whitespace: bool):
@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_choice_completion(sample_guided_choice, llm,
                                   guided_decoding_backend: str,
                                   disable_any_whitespace: bool):
@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         GRAMMAR_DECODING_BACKENDS)
 def test_guided_grammar(sample_sql_statements, llm,
                         guided_decoding_backend: str,
                         disable_any_whitespace: bool):
@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         GRAMMAR_DECODING_BACKENDS)
 def test_guided_json_object(llm, guided_decoding_backend: str,
                             disable_any_whitespace: bool):
     sampling_params = SamplingParams(
@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,

         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
+        # A list is not what was intended, but is still valid
+        # json.
+        assert isinstance(parsed_json, (dict, list))


 class CarType(str, Enum):
@@ -395,7 +402,7 @@ class CarDescription(BaseModel):

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
                                           disable_any_whitespace: bool):
     json_schema = CarDescription.model_json_schema()
@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GUIDED_DECODING_BACKENDS)
+                         ALL_DECODING_BACKENDS)
 def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
                                              disable_any_whitespace: bool):
     sample_output_schema = {
```
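`test_guided_json_completion_with_enum` derives its schema from a Pydantic model. The diff shows only the class headers and the `model_json_schema()` call; a sketch of that setup with hypothetical fields:

```python
from enum import Enum

from pydantic import BaseModel

class CarType(str, Enum):  # class header from the diff; members assumed
    SEDAN = "sedan"
    SUV = "suv"

class CarDescription(BaseModel):  # field names are illustrative
    brand: str
    car_type: CarType

# The JSON schema, not the Python classes, is what the backend constrains on.
json_schema = CarDescription.model_json_schema()
print(sorted(json_schema["properties"]))
```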

tests/model_executor/test_guided_processors.py
Lines changed: 9 additions & 21 deletions

```diff
@@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
                                   whitespace_pattern=None,
                                   reasoner=None)

-    token_ids = zephyr_7B_tokenzer.encode(
-        f"Give an example IPv4 address with this regex: {sample_regex}")
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
-    regex_LP(token_ids, tensor)
+    tensor = regex_LP([], tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)

-    token_ids = zephyr_7B_tokenzer.encode(
-        f"Give an employee profile that fits this schema: {sample_json_schema}"
-    )
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
-    json_LP(token_ids, tensor)
+    tensor = json_LP([], tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)

@@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
         seed=0,
         dtype="bfloat16",
     )
-    token_ids = zephyr_7B_tokenzer.encode(
-        f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

     regex_lp = get_local_guided_decoding_logits_processor(
@@ -92,21 +85,19 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
-    tensor = regex_lp(token_ids, tensor)
+    # allowed tokens at state 0
+    tensor = regex_lp([], tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)

-    token_ids = zephyr_7B_tokenzer.encode(
-        f"Give an employee profile that fits this schema: {sample_json_schema}"
-    )
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
         json_request, zephyr_7B_tokenzer, config)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
-    tensor = json_lp(token_ids, tensor)
+    tensor = json_lp([], tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)

@@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
         dtype="bfloat16",
     )
     token_ids = deepseek_r1_qwen_tokenizer.encode(
-        f"Give an example IPv4 address with this regex: {sample_regex}."
         "<think>here is the thinking process")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

@@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
         regex_request, deepseek_r1_qwen_tokenizer, config,
         reasoning_backend)
     assert regex_lp is not None
-    tensor = torch.rand(32000)
+    tensor = torch.rand(151664)
     original_tensor = torch.clone(tensor)
     tensor = regex_lp(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert torch.allclose(tensor, original_tensor)

     token_ids = deepseek_r1_qwen_tokenizer.encode(
-        f"Give an employee profile that fits this schema: {sample_json_schema}."
         "<think>here is the thinking process")
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
@@ -158,16 +147,15 @@ async def test_guided_logits_processor_with_reasoning(
         await get_guided_decoding_logits_processor(
             json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
     assert json_lp is not None
-    tensor = torch.rand(32000)
+    tensor = torch.rand(151664)
     original_tensor = torch.clone(tensor)
     tensor = json_lp(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert torch.allclose(tensor, original_tensor)

     # Thinking is over, so the tensor should change.
     token_ids = deepseek_r1_qwen_tokenizer.encode(
-        f"Give an employee profile that fits this schema: {sample_json_schema}."
-        "<think>here is the thinking process</think> Then")
+        "<think>here is the thinking process</think>")
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = get_local_guided_decoding_logits_processor(
@@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
         await get_guided_decoding_logits_processor(
             json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
     assert json_lp is not None
-    tensor = torch.rand(32000)
+    tensor = torch.rand(151664)
     original_tensor = torch.clone(tensor)
     tensor = json_lp(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
```
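The behavioral contract these tests now encode: an outlines logits processor is stateful and is called as `processor(generated_token_ids, logits)`, so an empty list means "state 0" and prompt tokens are no longer replayed through the FSM. A toy processor with the same call shape (not vLLM's implementation):

```python
import torch

class ToyMaskProcessor:
    """Masks every token not in `allowed`, mimicking only the call shape
    processor(past_token_ids, logits) -> logits used by the tests above."""

    def __init__(self, allowed: set[int]):
        self.allowed = allowed

    def __call__(self, past_token_ids: list[int],
                 logits: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(logits, float("-inf"))
        mask[torch.tensor(sorted(self.allowed))] = 0.0
        return logits + mask

lp = ToyMaskProcessor(allowed={1, 2, 3})
logits = torch.rand(32000)
out = lp([], logits)  # state 0: no generated tokens yet
assert out.shape == logits.shape
```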

tests/tool_use/test_tool_choice_required.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
     assert isinstance(schema, dict)

     # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
-    from outlines_core.fsm.json_schema import build_regex_from_schema
+    from outlines_core.json_schema import build_regex_from_schema
     regex = build_regex_from_schema(json.dumps(schema))
     compiled = re.compile(regex)
     matches = compiled.fullmatch(json.dumps(sample_output)) is not None
```
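This hunk tracks the rename in outlines_core 0.2: the helper moved from `outlines_core.fsm.json_schema` to `outlines_core.json_schema`. A small sketch of the call with a toy schema (whether the whitespace in the sample document is accepted depends on the generated pattern, so the result is printed rather than asserted):

```python
import json
import re

from outlines_core.json_schema import build_regex_from_schema

schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
regex = build_regex_from_schema(json.dumps(schema))
print(re.fullmatch(regex, '{"city": "Paris"}') is not None)
```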
