Commit 00f9b31

review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent e28c49e commit 00f9b31

4 files changed: +15, -16 lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 8 additions & 8 deletions
@@ -8,12 +8,11 @@
 
 import torch
 
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
@@ -23,7 +22,7 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
     get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reoder_batch_to_split_decodes_and_prefills,
+    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
@@ -237,13 +236,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
         self.vllm_config = vllm_config
+        self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
-        return reoder_batch_to_split_decodes_and_prefills(input_batch,
-                                                          scheduler_output,
-                                                          decode_threshold=1)
+        return reorder_batch_to_split_decodes_and_prefills(input_batch,
+                                                           scheduler_output,
+                                                           decode_threshold=1)
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
@@ -384,7 +384,7 @@ def build(self,
         page_size = self.kv_cache_spec.block_size
         device = self.device
         qo_indptr = common_attn_metadata.query_start_loc
-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
 
@@ -431,7 +431,7 @@ def build(self,
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
-        cache_dtype = self.runner.cache_config.cache_dtype
+        cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                 cache_dtype)
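
Taken together, the flashinfer.py hunks decouple the metadata builder from the runner object: the cache config is captured once in __init__ as self.cache_config, and the maximum sequence length is read from common_attn_metadata.seq_lens_cpu rather than self.runner.seq_lens_np. A minimal sketch of the resulting shape, assuming a simplified builder (the class name and any fields not shown in the hunks above are illustrative, not the exact vLLM code):

    # Sketch only: simplified stand-in for the FlashInfer metadata builder.
    class MetadataBuilderSketch:

        def __init__(self, kv_cache_spec, vllm_config, device):
            self.vllm_config = vllm_config
            # Cache the config up front instead of reaching through the runner.
            self.cache_config = vllm_config.cache_config
            self.kv_cache_spec = kv_cache_spec
            self.device = device

        def build(self, common_attn_metadata):
            # Per-batch sizes now come from the metadata object itself.
            max_seq_len = common_attn_metadata.seq_lens_cpu.max()
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                # Map to the FlashInfer fp8 KV-cache dtype, as in the last
                # hunk above (omitted here).
                pass
            return max_seq_len, cache_dtype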

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 2 additions & 3 deletions
@@ -6,8 +6,8 @@
 
 import torch
 
-from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -90,8 +90,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
-        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size(
-        )
+        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
 

vllm/v1/attention/backends/mla/common.py

Lines changed: 4 additions & 4 deletions
@@ -211,7 +211,7 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     get_per_layer_parameters, infer_global_hyperparameters,
-    reoder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 try:
@@ -525,9 +525,9 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
-        return reoder_batch_to_split_decodes_and_prefills(input_batch,
-                                                          scheduler_output,
-                                                          decode_threshold=1)
+        return reorder_batch_to_split_decodes_and_prefills(input_batch,
+                                                           scheduler_output,
+                                                           decode_threshold=1)
 
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor):

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 1 deletion
@@ -426,7 +426,7 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
 
 
-def reoder_batch_to_split_decodes_and_prefills(
+def reorder_batch_to_split_decodes_and_prefills(
     input_batch: "InputBatch",
     scheduler_output: "SchedulerOutput",
     decode_threshold: int = 1,
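
The rename in utils.py is mirrored at every call site: the FlashInfer and MLA builders both delegate their reorder_batch hook to the shared helper with decode_threshold=1. A usage sketch of that call pattern (the builder class below is hypothetical, the helper's body is not part of this commit, and the comment about its behavior is inferred from its name):

    # Sketch of the call pattern used by the backends above; the helper itself
    # lives in vllm/v1/attention/backends/utils.py and is not shown here.
    from vllm.v1.attention.backends.utils import (
        reorder_batch_to_split_decodes_and_prefills)


    class SomeMetadataBuilder:  # hypothetical builder class

        def reorder_batch(self, input_batch, scheduler_output) -> bool:
            # Presumably treats requests with at most decode_threshold
            # scheduled tokens as decodes and moves them ahead of prefills.
            return reorder_batch_to_split_decodes_and_prefills(
                input_batch, scheduler_output, decode_threshold=1)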
