 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
+from vllm.utils import bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -54,13 +54,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         # in the scheduler.
         self.num_cpu_blocks = cache_config.num_gpu_blocks
 
-        if cache_config.cache_dtype == "auto":
-            self.dtype = model_config.dtype
-        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
-            self.dtype = torch.float8_e5m2
-        else:
-            raise NotImplementedError(f"Unsupported KV cache type "
-                                      f"{cache_config.cache_dtype}.")
+        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
+                                                       model_config)
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
@@ -97,24 +92,31 @@ def swap_out(self, src_to_dst: torch.Tensor) -> None:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
 
+    @staticmethod
+    def get_kv_cache_dtype(cache_config: CacheConfig,
+                           model_config: ModelConfig):
+        if cache_config.cache_dtype == "auto":
+            return model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            return torch.float8_e5m2
+        else:
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
+
     @staticmethod
     def get_cache_block_size(
-        block_size: int,
-        cache_dtype: str,
+        cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
-        key_cache_block = block_size * num_heads * head_size
+        key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block if not model_config.use_mla else 0
         total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype == "auto":
-            dtype = model_config.dtype
-        else:
-            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
         dtype_size = torch.tensor([], dtype=dtype).element_size()
         return dtype_size * total
 
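For reference, a minimal standalone sketch of the per-block arithmetic that `get_cache_block_size` performs after this change; every number below is an illustrative assumption rather than a value read from any real model config.

```python
import torch

# Standalone sketch of the block-size arithmetic in
# CPUCacheEngine.get_cache_block_size; all sizes here are assumed for illustration.
block_size = 16      # cache_config.block_size (tokens per KV cache block)
num_kv_heads = 8     # model_config.get_num_kv_heads(parallel_config)
head_size = 128      # model_config.get_head_size()
num_layers = 32      # model_config.get_num_layers(parallel_config)
use_mla = False      # MLA models keep no separate value cache

key_cache_block = block_size * num_kv_heads * head_size
value_cache_block = 0 if use_mla else key_cache_block
total = num_layers * (key_cache_block + value_cache_block)

dtype = torch.bfloat16  # what get_kv_cache_dtype returns for "auto" on a bf16 model
dtype_size = torch.tensor([], dtype=dtype).element_size()
print(dtype_size * total)  # bytes per KV cache block: 2 * 32 * 2 * 16 * 8 * 128 = 2097152
```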
@@ -399,9 +401,9 @@ def init_distributed_environment(self) -> None:
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.
         """
-        return CPUCacheEngine.get_cache_block_size(
-            self.cache_config.block_size, self.cache_config.cache_dtype,
-            self.model_config, self.parallel_config)
+        return CPUCacheEngine.get_cache_block_size(self.cache_config,
+                                                   self.model_config,
+                                                   self.parallel_config)
 
     def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
         """Return CPUs id binding based on NUMA nodes.
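As a quick, self-contained restatement of the dtype mapping that the new `get_kv_cache_dtype` helper centralizes (the behaviour is taken from the hunks above; the free-function name below is only for illustration and is not part of vLLM):

```python
import torch

# Illustrative restatement of CPUCacheEngine.get_kv_cache_dtype; the function
# name is hypothetical, the mapping mirrors the diff above.
def resolve_kv_cache_dtype(cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    if cache_dtype == "auto":
        return model_dtype          # fall back to the model's own dtype
    if cache_dtype in ("fp8", "fp8_e5m2"):
        return torch.float8_e5m2    # both fp8 spellings map to e5m2 on CPU
    raise NotImplementedError(f"Unsupported KV cache type {cache_dtype}.")

print(resolve_kv_cache_dtype("auto", torch.bfloat16))  # torch.bfloat16
print(resolve_kv_cache_dtype("fp8", torch.bfloat16))   # torch.float8_e5m2
```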