File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -759,16 +759,15 @@ def get_kv_cache_torch_dtype(
759
759
model_dtype : Optional [Union [str , torch .dtype ]] = None ) -> torch .dtype :
760
760
if isinstance (cache_dtype , str ):
761
761
if cache_dtype == "auto" :
762
- if isinstance (model_dtype , str ):
762
+ if isinstance (model_dtype ,
763
+ str ) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE :
763
764
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE [model_dtype ]
764
765
elif isinstance (model_dtype , torch .dtype ):
765
766
torch_dtype = model_dtype
766
767
else :
767
768
raise ValueError (f"Invalid model dtype: { model_dtype } " )
768
- elif cache_dtype in [ "half" , "bfloat16" , "float" ] :
769
+ elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE :
769
770
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE [cache_dtype ]
770
- elif cache_dtype == "fp8" :
771
- torch_dtype = torch .uint8
772
771
else :
773
772
raise ValueError (f"Invalid kv cache dtype: { cache_dtype } " )
774
773
elif isinstance (cache_dtype , torch .dtype ):
You can’t perform that action at this time.
0 commit comments