diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 652a556659f..02d2c83ab15 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -9,7 +9,7 @@
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -17,6 +17,7 @@
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest
 
 EOS_TOKEN_ID = 50256
 
@@ -33,6 +34,7 @@ def create_scheduler(
     block_size: int = 16,
     max_model_len: Optional[int] = None,
     num_speculative_tokens: Optional[int] = None,
+    skip_tokenizer_init: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -65,6 +67,7 @@ def create_scheduler(
         trust_remote_code=True,
         dtype="float16",
         seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
     kwargs_cache = ({} if enable_prefix_caching is None else {
@@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
 ])
 def test_schedule(enable_prefix_caching: Optional[bool],
                   prompt_logprobs: Optional[int]):
-    '''Test scheduling. 
+    '''Test scheduling.
     Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
     '''
     scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@@ -1408,7 +1411,7 @@ def create_requests_with_priority(
 
 
 def test_priority_scheduling_basic_ordering():
-    """Test that requests are scheduled in priority order 
+    """Test that requests are scheduled in priority order
     (lower value = higher priority)."""
     scheduler = create_scheduler_with_priority()
 
@@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():
 
 
 def test_priority_scheduling_arrival_time_tiebreaker():
-    """Test that arrival time is used 
+    """Test that arrival time is used
     as tiebreaker when priorities are equal."""
     scheduler = create_scheduler_with_priority()
 
@@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():
 
 
 def test_priority_scheduling_preemption():
-    """Test that priority scheduling preempts 
+    """Test that priority scheduling preempts
     lower priority requests when memory is constrained."""
     # Create scheduler with very limited memory to force preemption
     scheduler = create_scheduler_with_priority(
@@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():
 
 
 def test_priority_scheduling_no_preemption_when_space_available():
-    """Test that preemption doesn't happen 
+    """Test that preemption doesn't happen
     when there's space for new requests."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=3,  # Allow 3 concurrent requests
@@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
 
 
 def test_priority_scheduling_preemption_victim_selection():
-    """Test that the correct victim is selected for 
+    """Test that the correct victim is selected for
     preemption based on priority and arrival time."""
     # This test verifies the priority-based victim selection logic
     # by checking the waiting queue order after adding requests with different
@@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():
 
 
 def test_priority_scheduling_fcfs_fallback():
-    """Test that FCFS behavior is maintained when all 
+    """Test that FCFS behavior is maintained when all
     requests have same priority."""
     scheduler = create_scheduler_with_priority()
 
@@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():
 
 
 def test_priority_scheduling_heap_property():
-    """Test that the waiting queue maintains heap 
+    """Test that the waiting queue maintains heap
     property for priority scheduling."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=1,  # Only one request can run at a time
@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
     # Verify requests were scheduled in priority order (lowest value first)
     expected_priorities = sorted(priorities)
     assert scheduled_priorities == expected_priorities
+
+
+def test_schedule_skip_tokenizer_init():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    requests = create_requests(num_requests=5)
+    for request in requests:
+        scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.grammar_bitmask is None
+
+
+def test_schedule_skip_tokenizer_init_structured_output_request():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    guided_params = GuidedDecodingParams(regex="[0-9]+")
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=16,
+        guided_decoding=guided_params,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1],
+        multi_modal_inputs=None,
+        multi_modal_hashes=None,
+        multi_modal_placeholders=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+        structured_output_request=StructuredOutputRequest(sampling_params),
+    )
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py
index 6284dcfb915..059106c62a2 100644
--- a/tests/v1/engine/test_llm_engine.py
+++ b/tests/v1/engine/test_llm_engine.py
@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
 
 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import pytest
 
-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
 
+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"
 
 
-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
 
 
@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
         yield vllm_model
 
 
+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC parametrization and skip_tokenizer_init=True."""
+    with _vllm_model(
+            request.param,
+            vllm_runner,
+            monkeypatch,
+            skip_tokenizer_init=True,
+    ) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""
 
@@ -62,14 +92,34 @@ def get_mostly_n_gt1() -> int:
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list
 
 
+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Structured output requests should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.
-    
+
     Args:
       vllm_model: VllmRunner instance under test.
       example_prompt: test fixture providing prompts for testing.
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 7e7703df2cf..9fc52543efd 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -152,6 +152,11 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
+            raise ValueError(
+                "Structured outputs require a tokenizer, so they can't be used with 'skip_tokenizer_init'"  # noqa: E501
+            )
+
         engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index c5500b9a384..839f1da8dd0 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -40,22 +40,25 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
-        # The default max_workers if not specified is the number of CPUs * 5,
-        # which is way too high since these tasks are CPU-bound, not I/O bound.
-        # We also know we would never dominate CPU usage with just grammar
-        # compilation, so we set it to half the number of CPUs.
-        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
-        ).get_lora_tokenizer(None)
-        reasoning_backend = vllm_config.decoding_config.reasoning_backend
-        if reasoning_backend:
-            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                reasoning_backend)
-            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # The default max_workers if not specified is the number of
+            # CPUs * 5, which is way too high since these tasks are CPU-bound,
+            # not I/O bound. We also know we would never dominate CPU usage
+            # with just grammar compilation, so we set it to half the number
+            # of CPUs.
+            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+            self.executor = ThreadPoolExecutor(max_workers=max_workers)
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_backend = \
+                self.vllm_config.decoding_config.reasoning_backend
+            if reasoning_backend:
+                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                    reasoning_backend)
+                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
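
Usage note (editor's addition, not part of the patch): a minimal sketch of the end-user behavior the new `Processor._validate_structured_output` check enforces, mirroring `test_compatibility_with_skip_tokenizer_init` and `test_schedule_skip_tokenizer_init_structured_output_request` above. The model name, prompt token IDs, and engine flags are illustrative assumptions taken from these tests, not requirements of the change.

```python
# Editor's sketch: assumes the V1 engine and the facebook/opt-125m model used
# by the tests in this patch; any small model behaves the same way.
import os

from vllm import LLM, TokensPrompt
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

os.environ["VLLM_USE_V1"] = "1"

# No tokenizer is loaded, so the StructuredOutputManager above never builds
# its grammar-compilation executor or tokenizer.
llm = LLM(model="facebook/opt-125m",
          skip_tokenizer_init=True,
          enforce_eager=True)

# Structured-output requests are rejected up front by the processor check
# added in this patch; plain pre-tokenized requests are still accepted
# (see test_schedule_skip_tokenizer_init).
guided = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(regex="[0-9]+"),
)
try:
    llm.generate(TokensPrompt(prompt_token_ids=[0, 1, 2]), guided)
except ValueError as exc:
    print(f"Rejected as expected: {exc}")
```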