
Commit 3cf5072

aarnphm authored and njhill committed
[Structured Outputs][V1] Skip structured outputs for models that don't contain tokenizers (vllm-project#20365)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Nick Hill <nhill@redhat.com>
1 parent c991d18 commit 3cf5072

File tree

4 files changed, +128 -31 lines changed

tests/v1/core/test_scheduler.py
tests/v1/engine/test_llm_engine.py
vllm/v1/engine/processor.py
vllm/v1/structured_output/__init__.py

tests/v1/core/test_scheduler.py

Lines changed: 48 additions & 9 deletions
@@ -9,14 +9,15 @@
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest

 EOS_TOKEN_ID = 50256

@@ -33,6 +34,7 @@ def create_scheduler(
     block_size: int = 16,
     max_model_len: Optional[int] = None,
     num_speculative_tokens: Optional[int] = None,
+    skip_tokenizer_init: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.

@@ -65,6 +67,7 @@
         trust_remote_code=True,
         dtype="float16",
         seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
     kwargs_cache = ({} if enable_prefix_caching is None else {
@@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
 ])
 def test_schedule(enable_prefix_caching: Optional[bool],
                   prompt_logprobs: Optional[int]):
-    '''Test scheduling.
+    '''Test scheduling.
     Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
     '''
     scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@@ -1408,7 +1411,7 @@ def create_requests_with_priority(


 def test_priority_scheduling_basic_ordering():
-    """Test that requests are scheduled in priority order
+    """Test that requests are scheduled in priority order
     (lower value = higher priority)."""
     scheduler = create_scheduler_with_priority()

@@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():


 def test_priority_scheduling_arrival_time_tiebreaker():
-    """Test that arrival time is used
+    """Test that arrival time is used
     as tiebreaker when priorities are equal."""
     scheduler = create_scheduler_with_priority()

@@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():


 def test_priority_scheduling_preemption():
-    """Test that priority scheduling preempts
+    """Test that priority scheduling preempts
     lower priority requests when memory is constrained."""
     # Create scheduler with very limited memory to force preemption
     scheduler = create_scheduler_with_priority(
@@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():


 def test_priority_scheduling_no_preemption_when_space_available():
-    """Test that preemption doesn't happen
+    """Test that preemption doesn't happen
     when there's space for new requests."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=3,  # Allow 3 concurrent requests
@@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():


 def test_priority_scheduling_preemption_victim_selection():
-    """Test that the correct victim is selected for
+    """Test that the correct victim is selected for
     preemption based on priority and arrival time."""
     # This test verifies the priority-based victim selection logic
     # by checking the waiting queue order after adding requests with different
@@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():


 def test_priority_scheduling_fcfs_fallback():
-    """Test that FCFS behavior is maintained when all
+    """Test that FCFS behavior is maintained when all
     requests have same priority."""
     scheduler = create_scheduler_with_priority()

@@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():


 def test_priority_scheduling_heap_property():
-    """Test that the waiting queue maintains heap
+    """Test that the waiting queue maintains heap
     property for priority scheduling."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=1,  # Only one request can run at a time
@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
     # Verify requests were scheduled in priority order (lowest value first)
     expected_priorities = sorted(priorities)
     assert scheduled_priorities == expected_priorities
+
+
+def test_schedule_skip_tokenizer_init():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    requests = create_requests(num_requests=5)
+    for request in requests:
+        scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.grammar_bitmask is None
+
+
+def test_schedule_skip_tokenizer_init_structured_output_request():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    guided_params = GuidedDecodingParams(regex="[0-9]+")
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=16,
+        guided_decoding=guided_params,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1],
+        multi_modal_inputs=None,
+        multi_modal_hashes=None,
+        multi_modal_placeholders=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+        structured_output_request=StructuredOutputRequest(sampling_params),
+    )
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
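
Reviewer note: the two new tests above can be run on their own. A minimal sketch, assuming a vLLM source checkout with the usual test dependencies installed (the -k filter simply matches the two new test names):

# Run only the new skip_tokenizer_init scheduler tests from the repo root.
import pytest

exit_code = pytest.main([
    "tests/v1/core/test_scheduler.py",
    "-k", "skip_tokenizer_init",
    "-v",
])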

tests/v1/engine/test_llm_engine.py

Lines changed: 56 additions & 6 deletions
@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations

 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import pytest

-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector

+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"


-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )


@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
     yield vllm_model


+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(
+            request.param,
+            vllm_runner,
+            monkeypatch,
+            skip_tokenizer_init=True,
+    ) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""

@@ -62,14 +92,34 @@ def get_mostly_n_gt1() -> int:
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list


+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Case 1: Structured output request should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.
-
+
     Args:
       vllm_model: VllmRunner instance under test.
       example_prompt: test fixture providing prompts for testing.

vllm/v1/engine/processor.py

Lines changed: 5 additions & 0 deletions
@@ -152,6 +152,11 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return

+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
+            raise ValueError(
+                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
+            )
+
         engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
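
Reviewer note: from the user's side, this check means a guided-decoding request against an engine started with skip_tokenizer_init=True now fails fast at validation time instead of breaking later during grammar compilation. A minimal sketch of the rejected path (not part of this diff), assuming a local vLLM install and the facebook/opt-125m model used by the tests in this PR:

# Sketch of the user-facing behavior exercised by the new test.
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Engine started without a tokenizer, so grammar compilation cannot work.
llm = LLM(model="facebook/opt-125m",
          skip_tokenizer_init=True,
          enforce_eager=True)

params = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(regex="[0-9]+"),
)

try:
    llm.generate(["Hello, my name is"], params)
except ValueError as err:
    # Expected: structured outputs require a tokenizer, so the request is rejected.
    print(f"Rejected as expected: {err}")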

vllm/v1/structured_output/__init__.py

Lines changed: 19 additions & 16 deletions
@@ -40,22 +40,25 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)

-        # The default max_workers if not specified is the number of CPUs * 5,
-        # which is way too high since these tasks are CPU-bound, not I/O bound.
-        # We also know we would never dominate CPU usage with just grammar
-        # compilation, so we set it to half the number of CPUs.
-        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
-        ).get_lora_tokenizer(None)
-        reasoning_backend = vllm_config.decoding_config.reasoning_backend
-        if reasoning_backend:
-            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                reasoning_backend)
-            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # The default max_workers if not specified is the number of
+            # CPUs * 5, which is way too high since these tasks are CPU-bound,
+            # not I/O bound. We also know we would never dominate CPU usage
+            # with just grammar compilation, so we set it to half the number
+            # of CPUs.
+            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+            self.executor = ThreadPoolExecutor(max_workers=max_workers)
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_backend = \
+                self.vllm_config.decoding_config.reasoning_backend
+            if reasoning_backend:
+                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                    reasoning_backend)
+                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)


     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
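
Reviewer note: the change above guards the entire tokenizer-dependent setup behind one flag, so the manager never builds its thread pool, tokenizer, or reasoner when tokenizer initialization is skipped, and later requests that need them are rejected upstream. A self-contained sketch of the same conditional-initialization pattern, illustrative only (the Manager class and its names are hypothetical, not vLLM code):

# Hypothetical illustration of conditional initialization; not part of this diff.
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


class Manager:

    def __init__(self, skip_heavy_init: bool):
        # Default to None so later code can tell whether heavy init ran.
        self.executor: Optional[ThreadPoolExecutor] = None
        if not skip_heavy_init:
            # Only pay for the thread pool when it can actually be used.
            self.executor = ThreadPoolExecutor(max_workers=2)

    def submit(self, fn, *args):
        if self.executor is None:
            # Mirrors the new behavior: work that needs the skipped
            # components is rejected instead of crashing at init time.
            raise ValueError("backend unavailable: heavy init was skipped")
        return self.executor.submit(fn, *args)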
