
Commit 27a8b26

andyxning authored and sfeng33 committed
[Kernel] refactor cpu worker v0 cache dtype (vllm-project#20080)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
1 parent 34c79e0 commit 27a8b26

File tree

1 file changed (+20 -18 lines changed)


vllm/worker/cpu_worker.py

Lines changed: 20 additions & 18 deletions
@@ -18,7 +18,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
+from vllm.utils import bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -54,13 +54,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         # in the scheduler.
         self.num_cpu_blocks = cache_config.num_gpu_blocks
 
-        if cache_config.cache_dtype == "auto":
-            self.dtype = model_config.dtype
-        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
-            self.dtype = torch.float8_e5m2
-        else:
-            raise NotImplementedError(f"Unsupported KV cache type "
-                                      f"{cache_config.cache_dtype}.")
+        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
+                                                       model_config)
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
@@ -97,24 +92,31 @@ def swap_out(self, src_to_dst: torch.Tensor) -> None:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
 
+    @staticmethod
+    def get_kv_cache_dtype(cache_config: CacheConfig,
+                           model_config: ModelConfig):
+        if cache_config.cache_dtype == "auto":
+            return model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            return torch.float8_e5m2
+        else:
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
+
     @staticmethod
     def get_cache_block_size(
-        block_size: int,
-        cache_dtype: str,
+        cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
-        key_cache_block = block_size * num_heads * head_size
+        key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block if not model_config.use_mla else 0
         total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype == "auto":
-            dtype = model_config.dtype
-        else:
-            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
         dtype_size = torch.tensor([], dtype=dtype).element_size()
         return dtype_size * total
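
To make the block-size arithmetic above concrete, here is a minimal standalone sketch of the same computation. The shape values (block_size=16, num_heads=8, head_size=128, num_layers=32) are illustrative assumptions, not taken from this commit.

    import torch

    # Illustrative shapes only; the real values come from the model and
    # parallel configs via get_head_size()/get_num_kv_heads()/etc.
    block_size, num_heads, head_size, num_layers = 16, 8, 128, 32

    key_cache_block = block_size * num_heads * head_size        # 16384 elements
    value_cache_block = key_cache_block                         # non-MLA model
    total = num_layers * (key_cache_block + value_cache_block)  # 1048576 elements

    # torch.float8_e5m2 stores one byte per element.
    dtype_size = torch.tensor([], dtype=torch.float8_e5m2).element_size()
    print(dtype_size * total)  # 1048576 bytes, i.e. 1 MiB per KV cache block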

@@ -399,9 +401,9 @@ def init_distributed_environment(self) -> None:
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.
         """
-        return CPUCacheEngine.get_cache_block_size(
-            self.cache_config.block_size, self.cache_config.cache_dtype,
-            self.model_config, self.parallel_config)
+        return CPUCacheEngine.get_cache_block_size(self.cache_config,
+                                                   self.model_config,
+                                                   self.parallel_config)
 
     def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
         """Return CPUs id binding based on NUMA nodes.
