  8   8  
  9   9  import torch
 10  10  
     11  +import vllm.envs as envs
 11  12  from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
 12  13                          BatchPrefillWithPagedKVCacheWrapper,
 13  14                          MultiLevelCascadeAttentionWrapper)
 14  15  from flashinfer.decode import trtllm_batch_decode_with_kv_cache
 15      -
 16      -import vllm.envs as envs
 17  16  from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 18  17                                                AttentionType)
 19  18  from vllm.config import VllmConfig

 23  22  from vllm.v1.attention.backends.utils import (
 24  23      AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
 25  24      get_kv_cache_layout, get_per_layer_parameters,
 26      -    infer_global_hyperparameters, reoder_batch_to_split_decodes_and_prefills,
     25  +    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
 27  26      split_decodes_and_prefills)
 28  27  from vllm.v1.kv_cache_interface import AttentionSpec
 29  28  

@@ -237,13 +236,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
237 236          self.global_hyperparameters: Optional[PerLayerParameters] = None
238 237  
239 238          self.vllm_config = vllm_config
    239  +        self.cache_config = vllm_config.cache_config
240 240          self.kv_cache_spec = kv_cache_spec
241 241  
242 242      def reorder_batch(self, input_batch: InputBatch,
243 243                        scheduler_output: SchedulerOutput) -> bool:
244      -        return reoder_batch_to_split_decodes_and_prefills(input_batch,
245      -                                                          scheduler_output,
246      -                                                          decode_threshold=1)
    244  +        return reorder_batch_to_split_decodes_and_prefills(input_batch,
    245  +                                                           scheduler_output,
    246  +                                                           decode_threshold=1)
247 247  
248 248      def _get_workspace_buffer(self):
249 249          if self._workspace_buffer is None:
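
For readers new to the helper being renamed here: `reorder_batch_to_split_decodes_and_prefills` reorders the request batch so that decode-only requests (at most `decode_threshold` new tokens) sit ahead of prefill requests, letting the decode and prefill paths each work on a contiguous slice of the batch. The toy sketch below illustrates only the partitioning idea; it is not vLLM's implementation, and `num_new_tokens` plus the return convention are assumptions made for illustration.

```python
from typing import List, Tuple


def split_decodes_and_prefills_toy(
        num_new_tokens: List[int],
        decode_threshold: int = 1) -> Tuple[List[int], List[int]]:
    """Partition request indices: decodes (<= threshold new tokens) first."""
    decodes = [i for i, n in enumerate(num_new_tokens) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(num_new_tokens) if n > decode_threshold]
    return decodes, prefills


# Requests 0 and 2 each need one new token (decodes); request 1 is a prefill.
print(split_decodes_and_prefills_toy([1, 37, 1]))  # ([0, 2], [1])
```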

@@ -384,7 +384,7 @@ def build(self,
384 384          page_size = self.kv_cache_spec.block_size
385 385          device = self.device
386 386          qo_indptr = common_attn_metadata.query_start_loc
387      -        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
    387  +        max_seq_len = common_attn_metadata.seq_lens_cpu.max()
388 388          seq_lens = common_attn_metadata.seq_lens
389 389          block_table_tensor = common_attn_metadata.block_table_tensor
390 390  

@@ -431,7 +431,7 @@ def build(self,
431 431          paged_kv_last_page_len = seq_lens % page_size
432 432          paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
433 433                                               page_size, paged_kv_last_page_len)
434      -        cache_dtype = self.runner.cache_config.cache_dtype
    434  +        cache_dtype = self.cache_config.cache_dtype
435 435          if cache_dtype.startswith("fp8"):
436 436              kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
437 437                  cache_dtype)
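
For context on the fp8 branch above: `cache_dtype` is the KV-cache dtype string carried by the cache config (for example `"auto"` or an `"fp8..."` variant), and `FlashInferBackend.get_fp8_dtype_for_flashinfer` resolves it to a concrete torch dtype for FlashInfer. The sketch below only illustrates that kind of string-to-dtype mapping under the common `fp8_e4m3`/`fp8_e5m2` spellings; the backend method itself is the authoritative table.

```python
import torch


def fp8_dtype_for_flashinfer_sketch(cache_dtype: str) -> torch.dtype:
    """Illustrative mapping only; not the backend's actual implementation."""
    if cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn  # 4 exponent bits, 3 mantissa bits
    if cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2    # 5 exponent bits, 2 mantissa bits
    raise ValueError(f"unsupported fp8 cache dtype: {cache_dtype}")
```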