
Commit 55ddaa0

Fix pre-commit
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent b6cd35d commit 55ddaa0

File tree

4 files changed (+10, -14 lines)


tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 3 additions & 2 deletions
@@ -317,6 +317,9 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 @pytest.fixture
 def test_prompts():
+    """
+    Adapted from tests/v1/e2e/test_spec_decode.py
+    """
     prompt_types = ["repeat", "sentence"]
     # Setting higher num prompts increases the chance of numerics mismatch
     # due to matrix multiplication numerics depending on batch dimension
@@ -326,8 +329,6 @@ def test_prompts():
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
 
-    # Generate a mixed batch of prompts, some of which can be easily
-    # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
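
The change documents the fixture's origin and drops a comment left over from the spec-decode test it was adapted from (n-gram prediction is not relevant to this KV-sharing test). For context, here is a minimal self-contained sketch of the seeded prompt-mix pattern the fixture relies on; the prompt bodies and the `num_prompts` value are illustrative assumptions, not the test file's exact contents:

```python
import random

import pytest


@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    prompt_types = ["repeat", "sentence"]
    num_prompts = 10  # assumption: the real test picks its own count

    # Seeding keeps the repeat/sentence mix identical across runs, so
    # numerics comparisons always see the same batch composition.
    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    prompts = []
    for kind in random_prompt_type_choices:
        word = random.choice(["test", "temp", "hello", "where"])
        # Hypothetical prompt bodies, for illustration only.
        if kind == "repeat":
            prompts.append(f"Repeat the word '{word}' ten times.")
        else:
            prompts.append(f"Write one sentence using the word '{word}'.")
    return prompts
```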

vllm/compilation/decorators.py

Lines changed: 0 additions & 2 deletions
@@ -25,8 +25,6 @@
 
 def skip_torch_compile(cls: _T) -> _T:
     cls._skip_compile_vllm = True
-    for base in cls.__bases__:
-        setattr(base,"_skip_compile_vllm",True)
     return cls
 
 
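The deleted loop stamped `_skip_compile_vllm` onto every direct base class, which would also disable compilation for unrelated subclasses sharing those bases; marking only the decorated class suffices, since its own subclasses still see the flag through ordinary attribute lookup. A minimal sketch of the surviving pattern, assuming the flag is consumed with a `getattr` check (the `Base`/`OtherModel` classes, the `TypeVar` bound, and the check itself are hypothetical scaffolding):

```python
from typing import TypeVar

_T = TypeVar("_T", bound=type)


def skip_torch_compile(cls: _T) -> _T:
    # Mark only this class; its subclasses inherit the flag via normal
    # attribute lookup, while other subclasses of its bases stay
    # unaffected (the removed loop broke that isolation).
    cls._skip_compile_vllm = True
    return cls


class Base:
    pass


@skip_torch_compile
class SkippedModel(Base):
    pass


class OtherModel(Base):
    pass


# Hypothetical consumer-side check:
assert getattr(SkippedModel, "_skip_compile_vllm", False) is True
assert getattr(OtherModel, "_skip_compile_vllm", False) is False
```
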
vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -284,6 +284,8 @@ def build(
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_np = common_attn_metadata.query_start_loc_np
+        if query_start_loc_np is None:
+            query_start_loc_np = self.runner.query_start_loc_np[:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
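
Since `CommonAttentionMetadata.query_start_loc_np` is now optional (see the utils.py change below), the FlashAttention builder falls back to the runner's persistent CPU-side buffer whenever the metadata arrives without it. A sketch of that fallback in isolation; everything except the `query_start_loc_np` name and the `num_reqs + 1` slice is made-up scaffolding:

```python
from typing import Optional

import numpy as np


class StubRunner:
    """Stand-in for the model runner's preallocated CPU-side buffers."""

    def __init__(self, max_num_reqs: int) -> None:
        # Cumulative query start offsets: one slot per request, plus one.
        self.query_start_loc_np = np.zeros(max_num_reqs + 1, dtype=np.int32)


def resolve_query_start_loc_np(metadata_value: Optional[np.ndarray],
                               runner: StubRunner,
                               num_reqs: int) -> np.ndarray:
    # Mirrors the diff: prefer the metadata's array; otherwise slice the
    # runner's buffer down to the batch_size + 1 entries in use.
    if metadata_value is None:
        return runner.query_start_loc_np[:num_reqs + 1]
    return metadata_value


runner = StubRunner(max_num_reqs=8)
assert resolve_query_start_loc_np(None, runner, num_reqs=4).shape == (5,)
```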

vllm/v1/attention/backends/utils.py

Lines changed: 5 additions & 10 deletions
@@ -37,9 +37,6 @@ class CommonAttentionMetadata:
     query_start_loc: torch.Tensor
     """(batch_size + 1,), the start location of each request in query Tensor"""
 
-    query_start_loc_np: np.ndarray
-    """(batch_size + 1,), numpy version of query_start_loc on the CPU"""
-
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
@@ -54,6 +51,9 @@ class CommonAttentionMetadata:
     decode_indices: Optional[torch.Tensor] = None
     """indices used for decoding"""
 
+    query_start_loc_np: Optional[np.ndarray] = None
+    """(batch_size + 1,), numpy equivalent of query_start_loc on the CPU"""
+
 
 M = TypeVar("M")
 
@@ -63,13 +63,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False
 
     @abstractmethod
-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        decode_only_common_attn_metadata: Optional[
-            CommonAttentionMetadata] = None,
-    ) -> M:
+    def build(self, common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
         """
         Central method that builds attention metadata.
         Some builders (MLA) require reorder_batch to be called prior to build.
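
This hunk also reverts `build()` to its two-argument signature, dropping the `decode_only_common_attn_metadata` parameter. And moving `query_start_loc_np` below `decode_indices`, rather than re-adding it in place, is not just cosmetic: assuming `CommonAttentionMetadata` is a `@dataclass` (its field-plus-docstring layout suggests it), a field with a default such as `= None` cannot precede fields without defaults, so the optional numpy mirror has to sit after the required tensors. A small sketch of that constraint, using a plain int in place of the tensor fields to stay self-contained:

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

# A defaulted field may not precede a required one; @dataclass raises
# "non-default argument ... follows default argument" at class creation.
try:

    @dataclass
    class Broken:
        query_start_loc_np: Optional[np.ndarray] = None  # defaulted first
        seq_lens: int  # required field after it -> TypeError

except TypeError as exc:
    print(f"rejected: {exc}")


# The commit's layout: required fields first, the optional CPU-side
# mirror with its None default at the end.
@dataclass
class Works:
    seq_lens: int
    query_start_loc_np: Optional[np.ndarray] = None


meta = Works(seq_lens=3)
assert meta.query_start_loc_np is None
```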
