                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -138,6 +139,11 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
+        if cache_config.cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        else:
+            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype

         self.is_multimodal_model = model_config.is_multimodal_model
@@ -480,7 +486,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
-                        dtype=attn_module.dtype,
+                        dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=False,
                     )
@@ -489,7 +495,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
-                        dtype=attn_module.dtype,
+                        dtype=self.kv_cache_dtype,
                         use_mla=False,
                     )
             elif attn_module.attn_type in (AttentionType.ENCODER,
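For context, here is a minimal standalone sketch of the dtype-resolution logic added to __init__ above. Assumptions: the dict literal below is a hypothetical stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE (only a subset, values chosen for illustration), and resolve_kv_cache_dtype is an invented helper, not a function in vLLM.

# Standalone sketch: mirrors the "auto" vs. explicit kv-cache dtype selection.
import torch

# Hypothetical stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE (subset).
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def resolve_kv_cache_dtype(cache_dtype: str,
                           model_dtype: torch.dtype) -> torch.dtype:
    """'auto' follows the model dtype; any other string is looked up."""
    if cache_dtype == "auto":
        return model_dtype
    return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

assert resolve_kv_cache_dtype("auto", torch.bfloat16) is torch.bfloat16
assert resolve_kv_cache_dtype("half", torch.bfloat16) is torch.half

The net effect of the diff is that get_kv_cache_spec reports this resolved self.kv_cache_dtype instead of attn_module.dtype, so a non-"auto" cache_config.cache_dtype setting propagates into the KV cache spec.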