Commit 30ed44e

Handle structured text with skip_tokenizer_init=True
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent d32de8a commit 30ed44e
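
For context: with skip_tokenizer_init=True, vLLM loads no tokenizer, so prompts must arrive as pre-computed token IDs and structured (guided) output cannot work, since grammar compilation needs the tokenizer's vocabulary. A minimal sketch of that mode follows; the model name and token IDs are placeholders, not taken from this commit.

# Sketch only: placeholder model name and token IDs.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

# No tokenizer is loaded, so prompts are passed as token IDs and
# outputs are left undetokenized.
outputs = llm.generate(
    {"prompt_token_ids": [1, 2, 3, 4]},
    SamplingParams(max_tokens=8, detokenize=False),
)

Before this commit, EngineCore always constructed a StructuredOutputManager (with a None tokenizer in this mode); the commit instead skips constructing the manager entirely and guards every use of it.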

3 files changed (+23 −15 lines)

vllm/v1/core/sched/scheduler.py

Lines changed: 6 additions & 4 deletions
@@ -544,7 +544,7 @@ def schedule(self) -> SchedulerOutput:
             self.requests,
             structured_output_request_ids,
             scheduled_spec_decode_tokens,
-        )
+        ) if structured_output_request_ids else None
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
@@ -826,8 +826,9 @@ def update_from_output(
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
-            if new_token_ids and self.structured_output_manager.should_advance(
-                    request):
+            if new_token_ids and self.structured_output_manager \
+                    and self.structured_output_manager.should_advance(
+                        request):
                 # NOTE: structured_output_request
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning
@@ -840,7 +841,8 @@ def update_from_output(
 
             # Add newly generated spec token ids to the request.
             if spec_token_ids is not None:
-                if self.structured_output_manager.should_advance(request):
+                if self.structured_output_manager \
+                        and self.structured_output_manager.should_advance(request):
                     metadata = request.structured_output_request
                     # Needs to happen after new_token_ids are accepted.
                     request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
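
Two things happen in this file. The schedule() hunk computes the grammar bitmask only when structured-output requests were actually scheduled (pairing with the early return removed from grammar_bitmask below), and the update_from_output() hunks rely on Python's short-circuiting `and` so that should_advance() is never called on a None manager. A toy sketch of that guard, with simplified names (not vLLM code):

# Toy sketch of the short-circuit guard used above.
class Scheduler:
    def __init__(self, structured_output_manager=None):
        # May be None when the engine runs with skip_tokenizer_init=True.
        self.structured_output_manager = structured_output_manager

    def update_from_output(self, request, new_token_ids):
        # `and` evaluates left to right: should_advance() is only reached
        # when new tokens exist AND a manager was constructed.
        if new_token_ids and self.structured_output_manager \
                and self.structured_output_manager.should_advance(request):
            print("advancing grammar for", request)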

vllm/v1/engine/core.py

Lines changed: 12 additions & 3 deletions
@@ -86,7 +86,11 @@ def __init__(self,
         self.collective_rpc("initialize_cache",
                             args=(num_gpu_blocks, num_cpu_blocks))
 
-        self.structured_output_manager = StructuredOutputManager(vllm_config)
+        if vllm_config.model_config.skip_tokenizer_init:
+            # Structured output generation requires a tokenizer
+            self.structured_output_manager = None
+        else:
+            self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.
         if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
@@ -198,7 +202,11 @@ def add_request(self, request: EngineCoreRequest):
                                          request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
+        if req.use_structured_output and self.structured_output_manager:
+            # We check for `structured_output_manager` because
+            # a StructuredOutputManager is not instantiated if a tokenizer
+            # is not initialized for the model.
+
             # Start grammar compilation asynchronously
             self.structured_output_manager.grammar_init(req)
 
@@ -299,7 +307,8 @@ def step_with_batch_queue(
         return engine_core_outputs, scheduled_batch
 
     def shutdown(self):
-        self.structured_output_manager.clear_backend()
+        if self.structured_output_manager:
+            self.structured_output_manager.clear_backend()
         if self.model_executor:
             self.model_executor.shutdown()
         if self.scheduler:
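
Taken together, these hunks make structured_output_manager an optional attribute that is guarded at construction, request admission, and shutdown. A self-contained toy sketch of that lifecycle (toy classes, not the vLLM implementation):

# Toy sketch of the Optional-manager lifecycle this commit introduces.
from typing import Optional


class StructuredOutputManager:
    def grammar_init(self, req: str) -> None:
        print(f"compiling grammar for {req}")

    def clear_backend(self) -> None:
        print("structured output backend cleared")


class EngineCore:
    def __init__(self, skip_tokenizer_init: bool) -> None:
        # No tokenizer means structured output is unsupported.
        self.structured_output_manager: Optional[StructuredOutputManager] = (
            None if skip_tokenizer_init else StructuredOutputManager())

    def add_request(self, req: str, use_structured_output: bool) -> None:
        if use_structured_output and self.structured_output_manager:
            self.structured_output_manager.grammar_init(req)

    def shutdown(self) -> None:
        if self.structured_output_manager:
            self.structured_output_manager.clear_backend()


EngineCore(skip_tokenizer_init=True).add_request("r1", True)   # silently skipped
EngineCore(skip_tokenizer_init=False).add_request("r2", True)  # compiles grammar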

vllm/v1/structured_output/__init__.py

Lines changed: 5 additions & 8 deletions
@@ -46,12 +46,11 @@ def __init__(self, vllm_config: VllmConfig):
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = None if vllm_config.model_config.skip_tokenizer_init \
-            else init_tokenizer_from_configs(
-                model_config=self.vllm_config.model_config,
-                scheduler_config=self.vllm_config.scheduler_config,
-                lora_config=self.vllm_config.lora_config,
-            ).get_lora_tokenizer(None)
+        self.tokenizer = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
+        ).get_lora_tokenizer(None)
         reasoning_backend = vllm_config.decoding_config.reasoning_backend
         if reasoning_backend:
             reasoner_cls = ReasoningParserManager.get_reasoning_parser(
@@ -116,8 +115,6 @@ def grammar_bitmask(
         scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
-        if not structured_output_request_ids:
-            return None
 
         max_num_spec_tokens = 0
         if self.vllm_config.speculative_config is not None:
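
With the skip_tokenizer_init branch gone, StructuredOutputManager.__init__ can assume a tokenizer is always available, and grammar_bitmask can assume a non-empty batch because its caller in the scheduler (the schedule() hunk above) now skips the call otherwise. A toy sketch of this hoist-the-check-to-the-caller refactor, with simplified signatures:

# Toy sketch: the empty-batch check moves from the callee to the caller.
def grammar_bitmask(request_ids: dict[str, int]) -> list[int]:
    # After the refactor, the callee may assume a non-empty batch.
    assert request_ids, "caller guarantees at least one structured request"
    return [0] * len(request_ids)  # stand-in for the real numpy bitmask

def schedule(request_ids: dict[str, int]):
    # The caller decides whether a bitmask is needed at all.
    return grammar_bitmask(request_ids) if request_ids else None

print(schedule({}))            # None; grammar_bitmask is never called
print(schedule({"req-1": 0}))  # [0]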
