1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ from __future__ import annotations
3
4
4
5
import random
5
- from typing import Optional
6
+ from typing import TYPE_CHECKING , Optional
6
7
7
8
import pytest
8
9
9
- from vllm import LLM , SamplingParams
10
+ from vllm import LLM
11
+ from vllm .sampling_params import GuidedDecodingParams , SamplingParams
10
12
from vllm .v1 .metrics .reader import Counter , Gauge , Histogram , Metric , Vector
11
13
14
+ if TYPE_CHECKING :
15
+ from tests .conftest import VllmRunner
16
+
12
17
MODEL = "facebook/opt-125m"
13
18
DTYPE = "half"
14
19
15
20
16
- def _vllm_model (apc : bool , vllm_runner , monkeypatch ):
21
+ def _vllm_model (
22
+ apc : bool ,
23
+ vllm_runner : VllmRunner ,
24
+ monkeypatch : pytest .MonkeyPatch ,
25
+ * ,
26
+ skip_tokenizer_init : bool = False ,
27
+ ):
17
28
"""Set up VllmRunner instance."""
18
29
monkeypatch .setenv ("VLLM_USE_V1" , "1" )
19
30
return vllm_runner (
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
23
34
enforce_eager = True ,
24
35
enable_prefix_caching = apc ,
25
36
gpu_memory_utilization = 0.5 ,
37
+ skip_tokenizer_init = skip_tokenizer_init ,
26
38
)
27
39
28
40
@@ -45,9 +57,25 @@ def vllm_model_apc(vllm_runner, monkeypatch):
45
57
yield vllm_model
46
58
47
59
60
@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Parametrize over prefix caching disabled/enabled
    params=[False, True])
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture with tokenizer initialization skipped.

    The engine is built with ``skip_tokenizer_init=True``, so features that
    need a tokenizer (e.g. structured outputs) are unavailable. Parametrized
    over ``enable_prefix_caching`` (False, True) via ``request.param``.
    """
    with _vllm_model(request.param,
                     vllm_runner,
                     monkeypatch,
                     skip_tokenizer_init=True) as vllm_model:
        yield vllm_model
74
+
48
75
def _get_test_sampling_params (
49
76
prompt_list : list [str ],
50
77
seed : Optional [int ] = 42 ,
78
+ structured_outputs : bool = False ,
51
79
) -> tuple [list [SamplingParams ], list [int ]]:
52
80
"""Generate random sampling params for a batch."""
53
81
@@ -62,14 +90,45 @@ def get_mostly_n_gt1() -> int:
62
90
n_list = [get_mostly_n_gt1 () for _ in range (len (prompt_list ))]
63
91
# High temperature to maximize the chance of unique completions
64
92
return [
65
- SamplingParams (temperature = 0.95 , top_p = 0.95 , n = n , seed = seed )
66
- for n in n_list
93
+ SamplingParams (
94
+ temperature = 0.95 ,
95
+ top_p = 0.95 ,
96
+ n = n ,
97
+ seed = seed ,
98
+ guided_decoding = GuidedDecodingParams (
99
+ regex = "[0-9]+" ) if structured_outputs else None ,
100
+ ) for n in n_list
67
101
], n_list
68
102
69
103
104
def test_compatibility_with_skip_tokenizer_init(
    vllm_model_skip_tokenizer_init: VllmRunner,
    example_prompts: list[str],
):
    """Exercise generation when the tokenizer was never initialized.

    Structured-output (guided decoding) requests must be rejected with a
    ``ValueError``, while ordinary sampling requests should still complete
    and return the requested number of completions per prompt.
    """
    llm: LLM = vllm_model_skip_tokenizer_init.model

    # Case 1: guided decoding needs a tokenizer, so the request must fail.
    guided_params, _ = _get_test_sampling_params(
        example_prompts,
        structured_outputs=True,
    )
    with pytest.raises(ValueError):
        llm.generate(example_prompts, guided_params)

    # Case 2: plain sampling does not need the tokenizer and must succeed.
    plain_params, expected_n = _get_test_sampling_params(
        example_prompts,
        structured_outputs=False,
    )
    completions = llm.generate(example_prompts, plain_params)

    # Sanity check mirroring the parallel-sampling test: each request
    # yields exactly the `n` completions it asked for.
    for completion, n in zip(completions, expected_n):
        assert len(completion.outputs) == n
128
+
70
129
def test_parallel_sampling (vllm_model , example_prompts ) -> None :
71
130
"""Test passes if parallel sampling `n>1` yields `n` unique completions.
72
-
131
+
73
132
Args:
74
133
vllm_model: VllmRunner instance under test.
75
134
example_prompt: test fixture providing prompts for testing.
0 commit comments