 from typing import TYPE_CHECKING, Any, Optional

 import torch
+
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)

 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              PerLayerParameters,
-                                              get_kv_cache_layout,
-                                              get_per_layer_parameters,
-                                              infer_global_hyperparameters,
-                                              reoder_batch_to_split_decodes_and_prefills,
-                                              split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
+    get_kv_cache_layout, get_per_layer_parameters,
+    infer_global_hyperparameters, reoder_batch_to_split_decodes_and_prefills,
+    split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:

@@ -450,7 +448,7 @@ def build(self,
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
             head_dim=self.kv_cache_spec.head_size,
             page_size=page_size,
-            kv_data_type=self.kv_cache_spec.dtype,
+            kv_data_type=kv_cache_dtype,
             q_data_type=self.vllm_config.model_config.dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
             num_decodes=num_decodes,
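The second hunk stops planning the FlashInfer metadata with `self.kv_cache_spec.dtype` and instead passes a `kv_cache_dtype` variable whose computation lies outside the shown context. A minimal sketch of the presumed intent, assuming the builder resolves vLLM's `cache_config.cache_dtype` string so that an FP8-quantized KV cache is described by its actual storage dtype rather than the unquantized model dtype; the helper name `resolve_kv_cache_dtype` is hypothetical and not part of this diff:

```python
import torch


def resolve_kv_cache_dtype(cache_dtype_str: str,
                           model_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical helper: map a cache_dtype string to the torch dtype
    the FlashInfer wrappers should be planned with.

    "auto" means the paged KV cache stores the model dtype; the fp8
    variants mean the cache holds 8-bit floats, so planning with the
    model dtype would mis-describe the cache contents.
    """
    if cache_dtype_str == "fp8_e5m2":
        return torch.float8_e5m2
    if cache_dtype_str in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    return model_dtype
```

With a resolution step like this, `kv_data_type` and `q_data_type` can legitimately differ: queries stay in the model dtype while the paged cache is read in its quantized storage dtype.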