
Commit 6b6d496

optimize get_kv_cache_torch_dtype (vllm-project#18531)
Signed-off-by: idellzheng <idellzheng@tencent.com>
Parent: aaa4ac1


vllm/utils.py

Lines changed: 3 additions & 4 deletions
@@ -759,16 +759,15 @@ def get_kv_cache_torch_dtype(
         model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
     if isinstance(cache_dtype, str):
         if cache_dtype == "auto":
-            if isinstance(model_dtype, str):
+            if isinstance(model_dtype,
+                          str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                 torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             elif isinstance(model_dtype, torch.dtype):
                 torch_dtype = model_dtype
             else:
                 raise ValueError(f"Invalid model dtype: {model_dtype}")
-        elif cache_dtype in ["half", "bfloat16", "float"]:
+        elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
             torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-        elif cache_dtype == "fp8":
-            torch_dtype = torch.uint8
         else:
             raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
     elif isinstance(cache_dtype, torch.dtype):
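
For context, a minimal self-contained sketch of the function after this change. The dict entries shown and everything past the hunk (the torch.dtype branch and the return) are assumptions made only to keep the sketch runnable; the authoritative versions live in vllm/utils.py.

    from typing import Optional, Union

    import torch

    # Assumed contents; the real table in vllm/utils.py also covers
    # variants such as "fp8_e4m3" and "fp8_e5m2".
    STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.uint8,  # fp8 cache entries are stored as raw uint8 bytes
    }


    def get_kv_cache_torch_dtype(
            cache_dtype: Optional[Union[str, torch.dtype]],
            model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
        if isinstance(cache_dtype, str):
            if cache_dtype == "auto":
                # "auto" defers to the model dtype; a string model dtype is
                # accepted only if it is a known key.
                if isinstance(model_dtype,
                              str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
                elif isinstance(model_dtype, torch.dtype):
                    torch_dtype = model_dtype
                else:
                    raise ValueError(f"Invalid model dtype: {model_dtype}")
            elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                # One dict lookup now covers "half", "bfloat16", "float",
                # and "fp8", replacing the hard-coded list plus fp8 branch.
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
            else:
                raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
        elif isinstance(cache_dtype, torch.dtype):
            torch_dtype = cache_dtype
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
        return torch_dtype

With "fp8" folded into STR_DTYPE_TO_TORCH_DTYPE, supporting a new string cache dtype takes one new dict entry instead of another elif branch, and the membership check on string model dtypes rejects unknown names up front rather than raising a KeyError on lookup.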
