
Commit 1cbd312

Fix lint
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent: 0ac11a8

File tree (3 files changed: +6 −5 lines)

tests/v1/e2e/test_kv_sharing_skip_prefill.py
vllm/config.py
vllm/v1/attention/backends/flash_attn.py

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
 
 import gc
 from collections.abc import Iterable
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import pytest
 import torch
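
The `List` → `list` change below follows PEP 585: since Python 3.9, built-in containers are directly subscriptable in annotations, so the `typing.List` import can be dropped. A minimal illustration (the function name here is illustrative, not from the test):

    from torch import nn

    # PEP 585 (Python 3.9+): built-in generics replace their typing aliases,
    # so list[nn.Module] works without importing typing.List.
    def stack_layers(layers: list[nn.Module]) -> nn.Sequential:
        return nn.Sequential(*layers)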
@@ -112,7 +112,7 @@ def __init__(
         *,
         vllm_config: VllmConfig,
         prefix: str = "",
-        layers: List[nn.Module],
+        layers: list[nn.Module],
     ):
         super().__init__()
         self.layers = layers
@@ -162,7 +162,8 @@ def __init__(self,
         )
 
         # Pre-allocate static buffers for CUDA graph
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.max_num_tokens = \
+            vllm_config.scheduler_config.max_num_batched_tokens
         self.dtype = vllm_config.model_config.dtype
         self.device = next(self.parameters()).device
         self.hidden_size = vllm_config.model_config.get_hidden_size()
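
The "pre-allocate static buffers" comment points at a standard CUDA graph constraint: replay reads and writes fixed memory addresses, so buffers are allocated once at the maximum size and smaller batches use a prefix slice. A hedged sketch of the idea (illustrative names and sizes, not vLLM's implementation):

    import torch

    max_num_tokens = 8192   # stands in for scheduler_config.max_num_batched_tokens
    hidden_size = 4096

    # Allocate once at the maximum size; a captured graph always sees this
    # same storage regardless of the actual batch size at replay time.
    static_hidden = torch.zeros(max_num_tokens, hidden_size,
                                dtype=torch.float16, device="cuda")

    def stage_inputs(hidden_states: torch.Tensor) -> torch.Tensor:
        n = hidden_states.shape[0]
        # Copy the live batch into a prefix slice of the static buffer.
        static_hidden[:n].copy_(hidden_states)
        return static_hidden[:n]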

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -4066,7 +4066,7 @@ class CompilationConfig:
     - None (default): capture sizes are inferred from vllm config.
     - list[int]: capture sizes are specified as given."""
     cudagraph_share_memory_pool: bool = True
-    """Whether to share a single global memory pool for each CUDA graph captured"""
+    """Whether to share a single global memory pool for each graph capture"""
     cudagraph_copy_inputs: bool = False
     """Whether to copy input tensors for
     cudagraph. If the caller can guarantee that the same input buffers
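
For context on what the reworded docstring describes: PyTorch exposes a public API for letting multiple CUDA graph captures draw from one allocator pool rather than each reserving its own. A hedged sketch using that public API (not vLLM's internals):

    import torch

    # One shared pool handle for all captures.
    pool = torch.cuda.graph_pool_handle()

    x = torch.zeros(1024, device="cuda")
    y = torch.zeros(2048, device="cuda")

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        out1 = x * 2

    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g2, pool=pool):  # same pool as g1, so memory is reused
        out2 = y * 2

Sharing one pool trades some allocator flexibility for a smaller total memory footprint when many graphs are captured.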

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def build_skip_prefill(
         # num_decode_tokens: [1, 2, 1]
         num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
 
-        # Calculate new query_start_loc only considering tokens in decode_indices
+        # Calculate new query_start_loc with tokens in decode_indices
         # decode_query_start_loc: [0, 1, 3, 4]
         decode_query_start_loc = torch.empty(num_reqs + 1,
                                              device=query_start_loc.device,
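
The comments in this hunk trace a small piece of arithmetic: per-request decode-token counts come from a bincount over request ids, and the new query start offsets are the exclusive cumulative sum of those counts. A standalone sketch reproducing the example values from the comments (assumed input shapes; this is not the `build_skip_prefill` implementation itself):

    import torch

    num_reqs = 3
    request_ids = torch.tensor([0, 1, 1, 2])  # request id of each kept decode token

    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
    # -> tensor([1, 2, 1])

    # Exclusive cumsum turns per-request counts into start offsets.
    decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
    # -> tensor([0, 1, 3, 4])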
