
Commit a49f180

chore: address Nick's comments and add tests for v1
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
1 parent 45f68a8 commit a49f180

File tree

3 files changed: +69 -13 lines changed


tests/v1/engine/test_llm_engine.py

Lines changed: 65 additions & 6 deletions
@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
 
 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import pytest
 
-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
 
+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"
 
 
-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: VllmRunner,
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
 
 
@@ -45,9 +57,25 @@ def vllm_model_apc(vllm_runner, monkeypatch):
         yield vllm_model
 
 
+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(request.param,
+                     vllm_runner,
+                     monkeypatch,
+                     skip_tokenizer_init=True) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""
 
@@ -62,14 +90,45 @@ def get_mostly_n_gt1() -> int:
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list
 
 
+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Case 1: Structured output request should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+    # Case 2: Standard generation without structured outputs should succeed.
+    sampling_params_list, n_list = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=False,
+    )
+    outputs = model.generate(example_prompts, sampling_params_list)
+
+    # Basic sanity checks similar to parallel sampling test
+    for out, n in zip(outputs, n_list):
+        assert len(out.outputs) == n
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.
-
+
     Args:
         vllm_model: VllmRunner instance under test.
         example_prompt: test fixture providing prompts for testing.
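
If needed, the new test can be run on its own with pytest's -k filter. The invocation below assumes a normal vLLM checkout with the test dependencies installed; it runs twice because the fixture is parametrized over prefix caching:

pytest tests/v1/engine/test_llm_engine.py -k test_compatibility_with_skip_tokenizer_init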

vllm/v1/engine/processor.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        if self.model_config.skip_tokenizer_init and self.decoding_config:
+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
             raise ValueError(
                 "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
             )
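
The corrected guard above is what the new test's first case exercises. A condensed sketch of the behavior (the prompt and engine arguments are illustrative; the test itself drives this through the VllmRunner fixture rather than constructing LLM directly):

import pytest
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True,
          enforce_eager=True, gpu_memory_utilization=0.5)

# Guided decoding needs a tokenizer, so the processor raises ValueError
# before any generation starts.
guided = SamplingParams(
    guided_decoding=GuidedDecodingParams(regex="[0-9]+"))
with pytest.raises(ValueError):
    llm.generate(["Hello, my name is"], guided)

# Ordinary sampling without guided decoding is still expected to succeed.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.95, top_p=0.95))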

vllm/v1/structured_output/__init__.py

Lines changed: 3 additions & 6 deletions
@@ -61,8 +61,7 @@ def __init__(self, vllm_config: VllmConfig):
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
 
     def grammar_init(self, request: Request) -> None:
-        if request.structured_output_request is None or \
-                self.vllm_config.model_config.skip_tokenizer_init:
+        if request.structured_output_request is None:
             return
 
         if TYPE_CHECKING:
@@ -119,8 +118,7 @@ def grammar_bitmask(
         scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
-        if not structured_output_request_ids \
-                or self.vllm_config.model_config.skip_tokenizer_init:
+        if not structured_output_request_ids:
             return None
 
         max_num_spec_tokens = 0
@@ -198,8 +196,7 @@ def grammar_bitmask(
         return bitmask_tensor.numpy()
 
     def should_advance(self, request: Request) -> bool:
-        if not request.use_structured_output \
-                or self.vllm_config.model_config.skip_tokenizer_init:
+        if not request.use_structured_output:
             return False
 
         # To determine whether we can advance the FSM.
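
Read together with the processor change above, dropping these skip_tokenizer_init checks looks consistent: once _validate_structured_output rejects any guided-decoding request when the tokenizer is skipped, grammar_init, grammar_bitmask, and should_advance should no longer see such requests, so the per-method guards reduce to their original conditions.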
