Commit e171dd5

Fix pre-commit
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent: 9f2a48c

4 files changed: +10 lines, -14 lines
tests/v1/e2e/test_kv_sharing_skip_prefill.py (3 additions, 2 deletions)

@@ -317,6 +317,9 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 @pytest.fixture
 def test_prompts():
+    """
+    Adapted from tests/v1/e2e/test_spec_decode.py
+    """
     prompt_types = ["repeat", "sentence"]
     # Setting higher num prompts increases the chance of numerics mismatch
     # due to matrix multiplication numerics depending on batch dimension
@@ -326,8 +329,6 @@ def test_prompts():
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
 
-    # Generate a mixed batch of prompts, some of which can be easily
-    # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
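
For readers outside the repo, a standalone sketch of the seeded selection the fixture relies on; the value of num_prompts is an assumption, since the real constant is defined between the two hunks:

import random

prompt_types = ["repeat", "sentence"]
num_prompts = 8  # assumed value; not shown in the hunks above

random.seed(0)  # fixed seed keeps the prompt mix identical across runs
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(random_prompt_type_choices)  # deterministic mix of "repeat"/"sentence"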

vllm/compilation/decorators.py (0 additions, 2 deletions)

@@ -25,8 +25,6 @@
 
 def skip_torch_compile(cls: _T) -> _T:
     cls._skip_compile_vllm = True
-    for base in cls.__bases__:
-        setattr(base,"_skip_compile_vllm",True)
     return cls
 
 
vllm/v1/attention/backends/flash_attn.py (2 additions, 0 deletions)

@@ -273,6 +273,8 @@ def build(
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_np = common_attn_metadata.query_start_loc_np
+        if query_start_loc_np is None:
+            query_start_loc_np = self.runner.query_start_loc_np[:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
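
The two added lines implement a prefer-then-fallback lookup: use the numpy copy carried in the metadata when present, otherwise slice the runner's persistent CPU buffer down to the current batch. A self-contained sketch, with runner_buf standing in for self.runner.query_start_loc_np:

import numpy as np

def resolve_query_start_loc_np(meta_np, runner_buf, num_reqs):
    # Prefer the metadata's CPU copy; otherwise slice the runner's buffer,
    # which may be longer than the current batch needs.
    if meta_np is None:
        return runner_buf[:num_reqs + 1]
    return meta_np

# Three requests with query lengths 4, 2, 5 -> cumulative starts [0, 4, 6, 11].
runner_buf = np.array([0, 4, 6, 11, 0, 0], dtype=np.int32)
print(resolve_query_start_loc_np(None, runner_buf, 3))  # [ 0  4  6 11]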

vllm/v1/attention/backends/utils.py (5 additions, 10 deletions)

@@ -33,9 +33,6 @@ class CommonAttentionMetadata:
     query_start_loc: torch.Tensor
     """(batch_size + 1,), the start location of each request in query Tensor"""
 
-    query_start_loc_np: np.ndarray
-    """(batch_size + 1,), numpy version of query_start_loc on the CPU"""
-
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
@@ -50,6 +47,9 @@ class CommonAttentionMetadata:
     decode_indices: Optional[torch.Tensor] = None
     """indices used for decoding"""
 
+    query_start_loc_np: Optional[np.ndarray] = None
+    """(batch_size + 1,), numpy equivalent of query_start_loc on the CPU"""
+
 
 M = TypeVar("M")
 
@@ -59,13 +59,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False
 
     @abstractmethod
-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        decode_only_common_attn_metadata: Optional[
-            CommonAttentionMetadata] = None,
-    ) -> M:
+    def build(self, common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
         """
         Central method that builds attention metadata.
         Some builders (MLA) require reorder_batch to be called prior to build.
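
Making query_start_loc_np optional also explains why it moved below decode_indices: dataclass fields with defaults must follow fields without them. A reduced sketch of the constraint, with the field list trimmed to the essentials:

from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch

@dataclass
class CommonAttentionMetadataSketch:
    query_start_loc: torch.Tensor  # required (no default)
    seq_lens: torch.Tensor         # required (no default)
    decode_indices: Optional[torch.Tensor] = None
    # Defaulted fields must trail the non-defaulted ones, so the
    # now-Optional numpy copy moves to the end of the class.
    query_start_loc_np: Optional[np.ndarray] = None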
