Commit 5d8490d

make precommit and lint
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
1 parent f2e195a commit 5d8490d

4 files changed: +44 additions, -31 deletions

vllm/config.py

Lines changed: 3 additions & 1 deletion
@@ -4412,7 +4412,9 @@ class ReasoningConfig:
     think_end_token_id: Optional[int] = None
     """Token ID that indicates the end of reasoning."""
 
-    def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None):
+    def __init__(self,
+                 think_start_token_id: Optional[int] = None,
+                 think_end_token_id: Optional[int] = None):
         self.think_start_token_id = think_start_token_id
         self.think_end_token_id = think_end_token_id
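
For context, a minimal usage sketch of the reformatted constructor on this branch; the token IDs below are placeholders, not real vocabulary entries (real values come from the model's tokenizer):

from vllm.config import ReasoningConfig

# Placeholder IDs standing in for the <think> / </think> token IDs that a
# reasoning parser would look up in the tokenizer vocabulary.
reasoning_config = ReasoningConfig(think_start_token_id=1000,
                                   think_end_token_id=1001)
assert reasoning_config.think_end_token_id == 1001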

vllm/v1/sample/logits_processor.py

Lines changed: 31 additions & 21 deletions
@@ -27,7 +27,8 @@ class MoveDirectionality(Enum):
 
 # (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int],
+                     list[int]]
 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]
@@ -497,13 +498,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
 class MaxThinkTokensLogitsProcessor(LogitsProcessor):
     """A logits processor that limits the maximum number of thinking tokens."""
 
-    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
+                 device: torch.device):
         """
         Args:
-            think_start_token_id (int): Token ID for the start of thinking section.
-            think_end_token_id (int): Token ID for the end of thinking section.
-            pin_memory (bool): Whether to use pinned memory for tensors.
-            device (torch.device): Device to use for tensor operations.
+            reasoning_config: Configuration for reasoning, which includes
+                the token IDs for thinking start and end.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
         """
         super().__init__()
         self.think_start_token_id = reasoning_config.think_start_token_id
@@ -519,19 +521,25 @@ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         return -1
 
     def is_argmax_invariant(self) -> bool:
-        """This logits processor can change the outcome of greedy sampling
-        by forcing that the thinking section ends after a certain number of tokens."""
+        """This logits processor can change the outcome of
+        greedy sampling by forcing that the thinking section
+        ends after a certain number of tokens."""
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
-            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+            for (index, params, prompt_tok_ids,
+                 output_tok_ids) in batch_update.added:
+                max_think_tokens = (params.max_think_tokens if isinstance(
+                    params, SamplingParams) else None)
                 if max_think_tokens is not None:
-                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    last_start = self._find_last_token_index(
+                        prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(
+                        prompt_tok_ids, self.think_end_token_id)
                     in_think = last_start > last_end
-                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+                    count = len(prompt_tok_ids) - (last_start +
+                                                   1) if in_think else 0
 
                     self._state[index] = {
                         "in_think": in_think,
@@ -542,13 +550,14 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
                     }
 
             for index in batch_update.removed:
-                self._state.pop(index, None)
+                self._state.pop(index, {})
 
             for i1, i2, direction in batch_update.moved:
                 if direction == MoveDirectionality.SWAP:
-                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                    self._state[i1], self._state[i2] = self._state[
+                        i2], self._state[i1]
                 else:
-                    self._state[i2] = self._state.pop(i1, None)
+                    self._state[i2] = self._state.pop(i1, {})
 
         # Update in_think and count for all active requests
         for state in self._state.values():
@@ -579,7 +588,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
             if not state:
                 continue
 
-            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+            if state["in_think"] and state["count"] >= state[
+                    "max_think_tokens"]:
                 mask[index] = True
 
         if mask.any():
@@ -589,8 +599,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
 
-def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
+def init_builtin_logitsprocs(
+        pin_memory_available: bool, max_num_reqs: int, device: torch.device,
+        reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.
@@ -619,8 +630,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
     )
     return LogitsProcessorManager(
         non_argmax_invariant=[
-            min_tokens_logitproc,
-            logit_bias_logitproc,
+            min_tokens_logitproc, logit_bias_logitproc,
             max_think_tokens_logitproc
         ],
         argmax_invariant=[min_p_logitproc],
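
The diff above is mostly line wrapping, but the idea it touches is worth spelling out. The following is a self-contained sketch (not the vLLM implementation) of the same technique: scan the prompt for the last thinking start/end markers to decide whether the request is still "thinking", count the tokens spent so far, and once the budget is exhausted mask every logit except the end-of-thinking token. Function names, token IDs, and shapes are illustrative assumptions.

import torch

def scan_prompt(prompt: list[int], think_start: int, think_end: int):
    # Mirror of the update_state bookkeeping: the request is "in think" when
    # the last start marker appears after the last end marker.
    last_start = max((i for i, t in enumerate(prompt) if t == think_start), default=-1)
    last_end = max((i for i, t in enumerate(prompt) if t == think_end), default=-1)
    in_think = last_start > last_end
    count = len(prompt) - (last_start + 1) if in_think else 0
    return in_think, count

def cap_thinking(logits: torch.Tensor, in_think: bool, count: int,
                 max_think_tokens: int, think_end: int) -> torch.Tensor:
    # Once the budget is spent, only the end-of-thinking token stays viable.
    if in_think and count >= max_think_tokens:
        keep = logits[think_end].clone()
        logits[:] = float("-inf")
        logits[think_end] = keep
    return logits

# Example with made-up token IDs: <think> = 7, </think> = 8.
in_think, count = scan_prompt([1, 2, 7, 3, 4], think_start=7, think_end=8)
assert in_think and count == 2
logits = cap_thinking(torch.zeros(16), in_think, count, max_think_tokens=2, think_end=8)
assert logits.argmax().item() == 8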

vllm/v1/worker/gpu_input_batch.py

Lines changed: 2 additions & 1 deletion
@@ -263,7 +263,8 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
         params = (request.sampling_params
                   if request.sampling_params else request.pooling_params)
         self.batch_update_builder.added.append(
-            (req_index, params, request.prompt_token_ids, request.output_token_ids))
+            (req_index, params, request.prompt_token_ids,
+             request.output_token_ids))
         return req_index
 
     def add_request(
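
For reference, the wrapped tuple follows the AddedRequest shape documented in logits_processor.py above; a hypothetical example with made-up values:

from vllm import SamplingParams

# (request index, per-request params, prompt token IDs, output token IDs)
added = (0, SamplingParams(max_tokens=16), [101, 102, 103], [])
req_index, params, prompt_token_ids, output_token_ids = added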

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 8 deletions
@@ -18,7 +18,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, ReasoningConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -39,8 +39,10 @@
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
+from vllm.reasoning import ReasoningParserManager
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, get_dtype_size,
@@ -71,10 +73,6 @@
 from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
-from vllm.config import ReasoningConfig
-from vllm.reasoning import ReasoningParserManager
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-
 if TYPE_CHECKING:
     import xgrammar as xgr
     import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
@@ -109,7 +107,8 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
-        if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'):
+        if self.vllm_config.decoding_config.reasoning_backend in (
+                'deepseek_r1', 'qwen'):
             tokenizer = init_tokenizer_from_configs(
                 model_config=self.vllm_config.model_config,
                 scheduler_config=self.vllm_config.scheduler_config,
@@ -120,8 +119,9 @@ def __init__(
             reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                 reasoning_backend)
             reasoning_parser = reasoner_cls(tokenizer=tokenizer)
-            self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id,
-                                                                think_end_token_id=reasoning_parser.think_end_token_id)
+            self.vllm_config.reasoning_config = ReasoningConfig(
+                think_start_token_id=reasoning_parser.think_start_token_id,
+                think_end_token_id=reasoning_parser.think_end_token_id)
 
         from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
         set_cpu_offload_max_bytes(
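
The relocated imports support the wiring in the last two hunks: when the reasoning backend is 'deepseek_r1' or 'qwen', the model runner builds a reasoning parser from the tokenizer and copies its think token IDs into a ReasoningConfig. A simplified stand-alone sketch of that flow, using a stub parser and tokenizer in place of the real vLLM objects (attribute names follow the diff):

from vllm.config import ReasoningConfig  # added on this branch

class StubReasoningParser:
    """Stand-in for a class resolved via ReasoningParserManager."""
    def __init__(self, tokenizer: dict[str, int]):
        # Real parsers resolve these IDs from the tokenizer vocabulary.
        self.think_start_token_id = tokenizer["<think>"]
        self.think_end_token_id = tokenizer["</think>"]

stub_tokenizer = {"<think>": 100, "</think>": 101}
parser = StubReasoningParser(stub_tokenizer)
reasoning_config = ReasoningConfig(
    think_start_token_id=parser.think_start_token_id,
    think_end_token_id=parser.think_end_token_id)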
