 28 |  28 | from tests.conftest import VllmRunner
 29 |  29 |
 30 |  30 | os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 31 |     | -MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 32 |     | -GuidedDecodingBackendV0 = [
 33 |     | -    "outlines",
 34 |     | -    "lm-format-enforcer",
 35 |     | -    "xgrammar",
 36 |     | -]
 37 |     | -GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"]
    |  31 | +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
    |  32 | +
    |  33 | +GuidedDecodingBackendV0 = ["outlines", "lm-format-enforcer", "xgrammar"]
    |  34 | +GuidedDecodingBackendV1 = ["xgrammar", "guidance"]
 38 |  35 | GuidedDecodingBackend = list(
 39 |  36 |     set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))
 40 |  37 |
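For reference, the merged `GuidedDecodingBackend` list deduplicates the backend shared by both engine versions via a set, which also discards ordering. A minimal standalone sketch of the same computation, using the list contents from the new code above (the `sorted()` in the assert is only there to make the check order-independent):

```python
# Sketch of the merged-backend computation from the diff above.
GuidedDecodingBackendV0 = ["outlines", "lm-format-enforcer", "xgrammar"]
GuidedDecodingBackendV1 = ["xgrammar", "guidance"]

# set() drops the duplicate "xgrammar", but the order of the resulting
# list (and hence the pytest parametrize IDs) is not deterministic.
GuidedDecodingBackend = list(
    set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))
assert sorted(GuidedDecodingBackend) == [
    "guidance", "lm-format-enforcer", "outlines", "xgrammar"
]
```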
@@ -87,26 +84,25 @@ def sample_json_schema():
 87 |  84 |     }
 88 |  85 |
 89 |  86 |
 90 |     | -@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
 91 |     | -def test_guided_json_completion(guided_decoding_backend: str,
 92 |     | -                                sample_json_schema):
 93 |     | -    if guided_decoding_backend == "xgrammar":
 94 |     | -        # xgrammar does not support json schema, will fall back to outlines, skip it
 95 |     | -        pytest.skip(
 96 |     | -            f"{guided_decoding_backend} will fall back to outlines, skip it")
    |  87 | +def check_backend(guided_decoding_backend: str):
 97 |  88 |     if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
 98 |  89 |             "VLLM_USE_V1") == "0":
 99 |     | -        # guidance does not support on v0, skip it
100 |     | -        pytest.skip(
101 |     | -            f"{guided_decoding_backend} does not support on v0, skip it")
    |  90 | +        pytest.skip(f"{guided_decoding_backend} does not support v0, skip it.")
102 |  91 |     if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
103 |  92 |             "VLLM_USE_V1") == "1":
104 |     | -        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
    |  93 | +        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it.")
    |  94 | +
    |  95 | +
    |  96 | +@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
    |  97 | +def test_guided_json_completion(guided_decoding_backend: str,
    |  98 | +                                sample_json_schema):
    |  99 | +    check_backend(guided_decoding_backend)
105 | 100 |
106 | 101 |     sampling_params = SamplingParams(
107 | 102 |         temperature=1.0,
108 |     | -        max_tokens=1000,
    | 103 | +        max_tokens=500,
109 | 104 |         guided_decoding=GuidedDecodingParams(json=sample_json_schema))
    | 105 | +
110 | 106 |     with VllmRunner(
111 | 107 |             MODEL_NAME,
112 | 108 |             seed=0,
@@ -138,19 +134,13 @@ def test_guided_json_completion(guided_decoding_backend: str,
138 | 134 |
139 | 135 | @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
140 | 136 | def test_guided_regex(guided_decoding_backend: str, sample_regex):
141 |     | -    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
142 |     | -            "VLLM_USE_V1") == "0":
143 |     | -        # guidance does not support on v0, skip it
144 |     | -        pytest.skip(
145 |     | -            f"{guided_decoding_backend} does not support on v0, skip it")
146 |     | -    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
147 |     | -            "VLLM_USE_V1") == "1":
148 |     | -        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
    | 137 | +    check_backend(guided_decoding_backend)
    | 138 | +
    | 139 | +    sampling_params = SamplingParams(
    | 140 | +        temperature=0.8,
    | 141 | +        top_p=0.95,
    | 142 | +        guided_decoding=GuidedDecodingParams(regex=sample_regex))
149 | 143 |
150 |     | -    sampling_params = SamplingParams(temperature=0.8,
151 |     | -                                     top_p=0.95,
152 |     | -                                     guided_decoding=GuidedDecodingParams(
153 |     | -                                         regex=sample_regex, ))
154 | 144 |     with VllmRunner(
155 | 145 |             MODEL_NAME,
156 | 146 |             seed=0,