Skip to content

Commit 0d49483

Browse files
authored
[TPU] fix kv cache dtype in model runner (#19244)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 90b78ec commit 0d49483

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@
2929
PlaceholderRange)
3030
from vllm.multimodal.utils import group_mm_inputs_by_modality
3131
from vllm.sequence import IntermediateTensors
32-
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
32+
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
33+
is_pin_memory_available)
3334
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
3435
PallasMetadata)
3536
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -138,6 +139,11 @@ def __init__(
138139

139140
self.pin_memory = is_pin_memory_available()
140141
self.dtype = self.model_config.dtype
142+
if cache_config.cache_dtype == "auto":
143+
self.kv_cache_dtype = self.dtype
144+
else:
145+
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
146+
cache_config.cache_dtype]
141147
self._hidden_states_dtype = self.dtype
142148

143149
self.is_multimodal_model = model_config.is_multimodal_model
@@ -480,7 +486,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
480486
block_size=block_size,
481487
num_kv_heads=attn_module.num_kv_heads,
482488
head_size=attn_module.head_size,
483-
dtype=attn_module.dtype,
489+
dtype=self.kv_cache_dtype,
484490
sliding_window=attn_module.sliding_window,
485491
use_mla=False,
486492
)
@@ -489,7 +495,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
489495
block_size=block_size,
490496
num_kv_heads=attn_module.num_kv_heads,
491497
head_size=attn_module.head_size,
492-
dtype=attn_module.dtype,
498+
dtype=self.kv_cache_dtype,
493499
use_mla=False,
494500
)
495501
elif attn_module.attn_type in (AttentionType.ENCODER,

0 commit comments

Comments (0)