From 8a99739e177256752ce5717d76d46b731efbf628 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 05:05:18 -0400 Subject: [PATCH 1/7] chore(so): support skip_tokenizer_init Signed-off-by: Aaron Pham --- tests/v1/core/test_scheduler.py | 60 ++++++++++++++++---- tests/v1/kv_connector/unit/utils.py | 3 +- vllm/v1/core/sched/scheduler.py | 6 +- vllm/v1/engine/core.py | 3 +- vllm/v1/engine/processor.py | 5 ++ vllm/v1/structured_output/__init__.py | 81 ++++++++++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 4 +- 7 files changed, 117 insertions(+), 45 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 652a556659f..79199ca07fb 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -17,6 +17,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.structured_output.request import StructuredOutputRequest EOS_TOKEN_ID = 50256 @@ -33,6 +34,7 @@ def create_scheduler( block_size: int = 16, max_model_len: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + skip_tokenizer_init: bool = False, ) -> Scheduler: '''Create scheduler under test. @@ -65,6 +67,7 @@ def create_scheduler( trust_remote_code=True, dtype="float16", seed=42, + skip_tokenizer_init=skip_tokenizer_init, ) # Cache config, optionally force APC kwargs_cache = ({} if enable_prefix_caching is None else { @@ -109,7 +112,8 @@ def create_scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), + structured_output_manager=StructuredOutputManager( + vllm_config=vllm_config), ) @@ -186,7 +190,7 @@ def test_get_num_unfinished_requests(): ]) def test_schedule(enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int]): - '''Test scheduling. + '''Test scheduling. 
+    '''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs ''' scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) @@ -1408,7 +1412,7 @@ def create_requests_with_priority( def test_priority_scheduling_basic_ordering(): - """Test that requests are scheduled in priority order + """Test that requests are scheduled in priority order (lower value = higher priority).""" scheduler = create_scheduler_with_priority() @@ -1437,7 +1441,7 @@ def test_priority_scheduling_basic_ordering(): def test_priority_scheduling_arrival_time_tiebreaker(): - """Test that arrival time is used + """Test that arrival time is used as tiebreaker when priorities are equal.""" scheduler = create_scheduler_with_priority() @@ -1495,7 +1499,7 @@ def test_priority_scheduling_mixed_priority_and_arrival(): def test_priority_scheduling_preemption(): - """Test that priority scheduling preempts + """Test that priority scheduling preempts lower priority requests when memory is constrained.""" # Create scheduler with very limited memory to force preemption scheduler = create_scheduler_with_priority( @@ -1576,7 +1580,7 @@ def test_priority_scheduling_preemption(): def test_priority_scheduling_no_preemption_when_space_available(): - """Test that preemption doesn't happen + """Test that preemption doesn't happen when there's space for new requests.""" scheduler = create_scheduler_with_priority( max_num_seqs=3, # Allow 3 concurrent requests @@ -1626,7 +1630,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): def test_priority_scheduling_preemption_victim_selection(): - """Test that the correct victim is selected for + """Test that the correct victim is selected for preemption based on priority and arrival time.""" # This test verifies the priority-based victim selection logic # by checking the waiting queue order after adding requests with different @@ -1743,7 +1747,7 @@ def test_priority_scheduling_waiting_queue_order(): def test_priority_scheduling_fcfs_fallback(): - """Test that FCFS behavior is maintained when all + """Test that FCFS behavior is maintained when all requests have same priority.""" scheduler = create_scheduler_with_priority() @@ -1811,7 +1815,7 @@ def test_priority_scheduling_with_limited_slots(): def test_priority_scheduling_heap_property(): - """Test that the waiting queue maintains heap + """Test that the waiting queue maintains heap property for priority scheduling.""" scheduler = create_scheduler_with_priority( max_num_seqs=1, # Only one request can run at a time @@ -1857,3 +1861,39 @@ def test_priority_scheduling_heap_property(): # Verify requests were scheduled in priority order (lowest value first) expected_priorities = sorted(priorities) assert scheduled_priorities == expected_priorities + + +def test_schedule_skip_tokenizer_init(): + scheduler = create_scheduler(skip_tokenizer_init=True) + requests = create_requests(num_requests=5) + for request in requests: + scheduler.add_request(request) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert output.grammar_bitmask is None + + +def test_schedule_skip_tokenizer_init_structured_output_request(): + scheduler = create_scheduler(skip_tokenizer_init=True) + guided_params = GuidedDecodingParams(regex="[0-9]+") + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=16, + guided_decoding=guided_params, + ) + request = Request( + request_id="0", + prompt_token_ids=[0, 1], + multi_modal_inputs=None, + multi_modal_hashes=None, + multi_modal_placeholders=None, 
+ sampling_params=sampling_params, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + structured_output_request=StructuredOutputRequest(sampling_params), + ) + scheduler.add_request(request) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 983d900606f..e0e3de3fe4a 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -107,7 +107,8 @@ def create_scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), + structured_output_manager=StructuredOutputManager( + vllm_config=vllm_config), ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe552db74e2..8ed9858b33d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -33,7 +33,9 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats -from vllm.v1.structured_output import StructuredOutputManager + +if TYPE_CHECKING: + from vllm.v1.structured_output import StructuredOutputManager logger = init_logger(__name__) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 453ed364dc8..ef2252457fe 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -86,7 +86,8 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - self.structured_output_manager = StructuredOutputManager(vllm_config) + self.structured_output_manager = StructuredOutputManager( + vllm_config=vllm_config) # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7e7703df2cf..49c2a502c2a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -152,6 +152,11 @@ def _validate_structured_output(self, params: SamplingParams) -> None: if not params.guided_decoding or not self.decoding_config: return + if self.model_config.skip_tokenizer_init and self.decoding_config: + raise ValueError( + "'skip_tokenizer_init' is specified during engine startup. This implies that the model doesn't contain sufficient files to setup tokenizers, which structured outputs requires tokenizers to work. Specifying structured outputs parameters will not be supported in conjunction with 'skip_tokenizer_init'." # noqa: E501 + ) + engine_level_backend = self.decoding_config.backend if params.guided_decoding.backend: # Request-level backend selection is not supported in V1. 
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index c5500b9a384..c5c8d7c6a87 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -6,6 +6,9 @@ from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Optional +from pydantic import ConfigDict, Field +from pydantic.dataclasses import dataclass + from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager @@ -29,36 +32,56 @@ logger = init_logger(__name__) +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class StructuredOutputManager: """Engine-level manager for structured output requests.""" - - def __init__(self, vllm_config: VllmConfig): - self.backend: Optional[StructuredOutputBackend] = None - self.reasoner: Optional[ReasoningParser] = None - self.vllm_config = vllm_config - - self._grammar_bitmask: Optional[torch.Tensor] = None - self._full_mask = torch.tensor(-1, dtype=torch.int32) - - # The default max_workers if not specified is the number of CPUs * 5, - # which is way too high since these tasks are CPU-bound, not I/O bound. - # We also know we would never dominate CPU usage with just grammar - # compilation, so we set it to half the number of CPUs. - max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) - self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) - reasoning_backend = vllm_config.decoding_config.reasoning_backend - if reasoning_backend: - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) - self.reasoner = reasoner_cls(tokenizer=self.tokenizer) + vllm_config: VllmConfig + + backend: Optional[StructuredOutputBackend] = Field( + default=None, + init=False, + repr=False, + ) + reasoner: Optional[ReasoningParser] = Field( + default=None, + init=False, + repr=False, + ) + _grammar_bitmask: Optional[torch.Tensor] = Field( + default=None, + init=False, + repr=False, + ) + _full_mask: torch.Tensor = Field( + default_factory=lambda: torch.tensor(-1, dtype=torch.int32), + init=False, + repr=False, + ) + + def __post_init__(self): + if not self.vllm_config.model_config.skip_tokenizer_init: + # The default max_workers if not specified is the number + # of CPUs * 5, which is way too high since these tasks are + # CPU-bound, not I/O bound. We also know we would never dominate + # CPU usage with just grammar compilation, so we set it to half + # the number of CPUs. 
+ max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) + reasoning_backend = \ + self.vllm_config.decoding_config.reasoning_backend + if reasoning_backend: + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: - if request.structured_output_request is None: + if request.structured_output_request is None or \ + self.vllm_config.model_config.skip_tokenizer_init: return if TYPE_CHECKING: @@ -115,7 +138,8 @@ def grammar_bitmask( scheduled_spec_decode_tokens: dict[str, list[int]], ) -> Optional[npt.NDArray[np.int32]]: # Prepare the structured output bitmask for this batch. - if not structured_output_request_ids: + if not structured_output_request_ids \ + or self.vllm_config.model_config.skip_tokenizer_init: return None max_num_spec_tokens = 0 @@ -193,7 +217,8 @@ def grammar_bitmask( return bitmask_tensor.numpy() def should_advance(self, request: Request) -> bool: - if not request.use_structured_output: + if not request.use_structured_output \ + or self.vllm_config.model_config.skip_tokenizer_init: return False # To determine whether we can advance the FSM. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index df9d69006fc..95e4bb408cf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -73,13 +73,11 @@ sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: - import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput else: - xgr = LazyLoader("xgr", globals(), "xgrammar") xgr_torch_compile = LazyLoader( "xgr_torch_compile", globals(), "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") @@ -1960,7 +1958,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor): Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. This is to help balance expert-selection - during profile_run - - during DP rank dummy run + - during DP rank dummy run """ dp_size = self.vllm_config.parallel_config.data_parallel_size randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 From 45f68a8e90953829fdf37774553aaf4e2594b4dd Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 07:26:03 -0400 Subject: [PATCH 2/7] revert: remove dataclass initialization and update warning messages Co-authored-by: Nick Hill Signed-off-by: Aaron Pham --- vllm/v1/engine/processor.py | 2 +- vllm/v1/structured_output/__init__.py | 49 ++++++++------------------- 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 49c2a502c2a..685aef166db 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -154,7 +154,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None: if self.model_config.skip_tokenizer_init and self.decoding_config: raise ValueError( - "'skip_tokenizer_init' is specified during engine startup. 
This implies that the model doesn't contain sufficient files to setup tokenizers, which structured outputs requires tokenizers to work. Specifying structured outputs parameters will not be supported in conjunction with 'skip_tokenizer_init'." # noqa: E501 + "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 ) engine_level_backend = self.decoding_config.backend diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index c5c8d7c6a87..3a200a0f3e1 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -6,9 +6,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Optional -from pydantic import ConfigDict, Field -from pydantic.dataclasses import dataclass - from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager @@ -32,39 +29,23 @@ logger = init_logger(__name__) -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class StructuredOutputManager: """Engine-level manager for structured output requests.""" - vllm_config: VllmConfig - - backend: Optional[StructuredOutputBackend] = Field( - default=None, - init=False, - repr=False, - ) - reasoner: Optional[ReasoningParser] = Field( - default=None, - init=False, - repr=False, - ) - _grammar_bitmask: Optional[torch.Tensor] = Field( - default=None, - init=False, - repr=False, - ) - _full_mask: torch.Tensor = Field( - default_factory=lambda: torch.tensor(-1, dtype=torch.int32), - init=False, - repr=False, - ) - - def __post_init__(self): + + def __init__(self, vllm_config: VllmConfig): + self.backend: Optional[StructuredOutputBackend] = None + self.reasoner: Optional[ReasoningParser] = None + self.vllm_config = vllm_config + + self._grammar_bitmask: Optional[torch.Tensor] = None + self._full_mask = torch.tensor(-1, dtype=torch.int32) + if not self.vllm_config.model_config.skip_tokenizer_init: - # The default max_workers if not specified is the number - # of CPUs * 5, which is way too high since these tasks are - # CPU-bound, not I/O bound. We also know we would never dominate - # CPU usage with just grammar compilation, so we set it to half - # the number of CPUs. + # The default max_workers if not specified is the number of + # CPUs * 5, which is way too high since these tasks are CPU-bound, + # not I/O bound. We also know we would never dominate CPU usage + # with just grammar compilation, so we set it to half the number + # of CPUs. 
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( @@ -73,7 +54,7 @@ def __post_init__(self): lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend + self.vllm_config.decoding_config.reasoning_backend if reasoning_backend: reasoner_cls = ReasoningParserManager.get_reasoning_parser( reasoning_backend) From a49f18047a6b6507cc93ee48aafaab3f480d6f53 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 08:10:32 -0400 Subject: [PATCH 3/7] chore: address Nick's comments and add tests for v1 Signed-off-by: Aaron Pham --- tests/v1/engine/test_llm_engine.py | 71 ++++++++++++++++++++++++--- vllm/v1/engine/processor.py | 2 +- vllm/v1/structured_output/__init__.py | 9 ++-- 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 6284dcfb915..9f54ed0c573 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,19 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations import random -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest -from vllm import LLM, SamplingParams +from vllm import LLM +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector +if TYPE_CHECKING: + from tests.conftest import VllmRunner + MODEL = "facebook/opt-125m" DTYPE = "half" -def _vllm_model(apc: bool, vllm_runner, monkeypatch): +def _vllm_model( + apc: bool, + vllm_runner: VllmRunner, + monkeypatch: pytest.MonkeyPatch, + *, + skip_tokenizer_init: bool = False, +): """Set up VllmRunner instance.""" monkeypatch.setenv("VLLM_USE_V1", "1") return vllm_runner( @@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch): enforce_eager=True, enable_prefix_caching=apc, gpu_memory_utilization=0.5, + skip_tokenizer_init=skip_tokenizer_init, ) @@ -45,9 +57,25 @@ def vllm_model_apc(vllm_runner, monkeypatch): yield vllm_model +@pytest.fixture( + # Function scope decouples tests & allows + # env var adjustment via monkeypatch + scope="function", + # Prefix caching + params=[False, True]) +def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch): + """VllmRunner test fixture with APC.""" + with _vllm_model(request.param, + vllm_runner, + monkeypatch, + skip_tokenizer_init=True) as vllm_model: + yield vllm_model + + def _get_test_sampling_params( prompt_list: list[str], seed: Optional[int] = 42, + structured_outputs: bool = False, ) -> tuple[list[SamplingParams], list[int]]: """Generate random sampling params for a batch.""" @@ -62,14 +90,45 @@ def get_mostly_n_gt1() -> int: n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))] # High temperature to maximize the chance of unique completions return [ - SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed) - for n in n_list + SamplingParams( + temperature=0.95, + top_p=0.95, + n=n, + seed=seed, + guided_decoding=GuidedDecodingParams( + regex="[0-9]+") if structured_outputs else None, + ) for n in n_list ], n_list +def test_compatibility_with_skip_tokenizer_init( + vllm_model_skip_tokenizer_init: VllmRunner, + example_prompts: list[str], +): + # Case 1: Structured output request should raise an error. 
+ sampling_params_list, _ = _get_test_sampling_params( + example_prompts, + structured_outputs=True, + ) + model: LLM = vllm_model_skip_tokenizer_init.model + with pytest.raises(ValueError): + _ = model.generate(example_prompts, sampling_params_list) + + # Case 2: Standard generation without structured outputs should succeed. + sampling_params_list, n_list = _get_test_sampling_params( + example_prompts, + structured_outputs=False, + ) + outputs = model.generate(example_prompts, sampling_params_list) + + # Basic sanity checks similar to parallel sampling test + for out, n in zip(outputs, n_list): + assert len(out.outputs) == n + + def test_parallel_sampling(vllm_model, example_prompts) -> None: """Test passes if parallel sampling `n>1` yields `n` unique completions. - + Args: vllm_model: VllmRunner instance under test. example_prompt: test fixture providing prompts for testing. diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 685aef166db..9fc52543efd 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -152,7 +152,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None: if not params.guided_decoding or not self.decoding_config: return - if self.model_config.skip_tokenizer_init and self.decoding_config: + if self.model_config.skip_tokenizer_init and params.guided_decoding: raise ValueError( "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3a200a0f3e1..839f1da8dd0 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -61,8 +61,7 @@ def __init__(self, vllm_config: VllmConfig): self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: - if request.structured_output_request is None or \ - self.vllm_config.model_config.skip_tokenizer_init: + if request.structured_output_request is None: return if TYPE_CHECKING: @@ -119,8 +118,7 @@ def grammar_bitmask( scheduled_spec_decode_tokens: dict[str, list[int]], ) -> Optional[npt.NDArray[np.int32]]: # Prepare the structured output bitmask for this batch. - if not structured_output_request_ids \ - or self.vllm_config.model_config.skip_tokenizer_init: + if not structured_output_request_ids: return None max_num_spec_tokens = 0 @@ -198,8 +196,7 @@ def grammar_bitmask( return bitmask_tensor.numpy() def should_advance(self, request: Request) -> bool: - if not request.use_structured_output \ - or self.vllm_config.model_config.skip_tokenizer_init: + if not request.use_structured_output: return False # To determine whether we can advance the FSM. 
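The user-facing behaviour that patch 3 pins down, sketched outside the series (this snippet is not part of any commit; the model name and token ids are placeholders): with skip_tokenizer_init=True the engine loads no tokenizer, so prompts are supplied as token ids, plain sampling still works, and a request that asks for structured output is rejected at validation time with the ValueError added in processor.py.

from vllm import LLM, TokensPrompt
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Engine started without a tokenizer; prompts must already be token ids.
llm = LLM(model="facebook/opt-125m",
          skip_tokenizer_init=True,
          enforce_eager=True)
prompt = TokensPrompt(prompt_token_ids=[0, 1, 2])

# Plain sampling goes through and returns raw token ids.
outputs = llm.generate(prompt, SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].token_ids)

# Structured output needs a tokenizer, so this request is rejected up front.
guided = SamplingParams(guided_decoding=GuidedDecodingParams(regex="[0-9]+"))
try:
    llm.generate(prompt, guided)
except ValueError as exc:
    print(f"rejected: {exc}")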
From 7d2ad08a339bb1b206dbc7606bc02777cb73ba92 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 08:12:39 -0400 Subject: [PATCH 4/7] revert: remove misc changes Signed-off-by: Aaron Pham --- tests/v1/core/test_scheduler.py | 3 +-- tests/v1/kv_connector/unit/utils.py | 3 +-- vllm/v1/core/sched/scheduler.py | 6 ++---- vllm/v1/engine/core.py | 3 +-- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 79199ca07fb..02d2c83ab15 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -112,8 +112,7 @@ def create_scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, - structured_output_manager=StructuredOutputManager( - vllm_config=vllm_config), + structured_output_manager=StructuredOutputManager(vllm_config), ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index e0e3de3fe4a..983d900606f 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -107,8 +107,7 @@ def create_scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, - structured_output_manager=StructuredOutputManager( - vllm_config=vllm_config), + structured_output_manager=StructuredOutputManager(vllm_config), ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8ed9858b33d..fe552db74e2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -33,9 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats - -if TYPE_CHECKING: - from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.structured_output import StructuredOutputManager logger = init_logger(__name__) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ef2252457fe..453ed364dc8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -86,8 +86,7 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - self.structured_output_manager = StructuredOutputManager( - vllm_config=vllm_config) + self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): From 31e40608c623ab88428863a93b4dbc46859bac88 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 08:14:56 -0400 Subject: [PATCH 5/7] perf(test): improve test time by removing duplicates Signed-off-by: Aaron Pham --- tests/v1/engine/test_llm_engine.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 9f54ed0c573..27d3bb02ab5 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -114,17 +114,6 @@ def test_compatibility_with_skip_tokenizer_init( with pytest.raises(ValueError): _ = model.generate(example_prompts, sampling_params_list) - # Case 2: Standard generation without structured outputs should succeed. 
- sampling_params_list, n_list = _get_test_sampling_params( - example_prompts, - structured_outputs=False, - ) - outputs = model.generate(example_prompts, sampling_params_list) - - # Basic sanity checks similar to parallel sampling test - for out, n in zip(outputs, n_list): - assert len(out.outputs) == n - def test_parallel_sampling(vllm_model, example_prompts) -> None: """Test passes if parallel sampling `n>1` yields `n` unique completions. From 700a4b23f2ca2b05dacd685c0bfa2447a9cc6227 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 08:15:52 -0400 Subject: [PATCH 6/7] revert: remove misc changes Signed-off-by: Aaron Pham --- vllm/v1/worker/gpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 95e4bb408cf..df9d69006fc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -73,11 +73,13 @@ sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: + import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput else: + xgr = LazyLoader("xgr", globals(), "xgrammar") xgr_torch_compile = LazyLoader( "xgr_torch_compile", globals(), "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") @@ -1958,7 +1960,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor): Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. This is to help balance expert-selection - during profile_run - - during DP rank dummy run + - during DP rank dummy run """ dp_size = self.vllm_config.parallel_config.data_parallel_size randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 From 8b4df50e8b5a17c7bfde9ff71ca2d11862351d87 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 2 Jul 2025 15:48:42 -0400 Subject: [PATCH 7/7] chore: fix precommit Signed-off-by: Aaron Pham --- tests/v1/engine/test_llm_engine.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 27d3bb02ab5..059106c62a2 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -20,7 +20,7 @@ def _vllm_model( apc: bool, - vllm_runner: VllmRunner, + vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, *, skip_tokenizer_init: bool = False, @@ -65,10 +65,12 @@ def vllm_model_apc(vllm_runner, monkeypatch): params=[False, True]) def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch): """VllmRunner test fixture with APC.""" - with _vllm_model(request.param, - vllm_runner, - monkeypatch, - skip_tokenizer_init=True) as vllm_model: + with _vllm_model( + request.param, + vllm_runner, + monkeypatch, + skip_tokenizer_init=True, + ) as vllm_model: yield vllm_model