import torch

-import vllm.envs as envs
-from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
    SimpleBuffer)
from vllm.logger import init_logger
@@ -37,9 +37,7 @@ def __init__(
    ):

        self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)

        if self.config.kv_connector == "PyNcclConnector":
            from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states(
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

        # query_lens contains new KV caches that are added to vLLM.
        # so we will send them to decode instance
@@ -212,13 +186,8 @@ def send_kv_caches_and_hidden_states(
            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
@@ -248,12 +217,12 @@ def recv_kv_caches_and_hidden_states(
        # and hidden states.
        bypass_model_exec = True

-        model_config = model_executable.model.config
-
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer

        hidden_or_intermediate_states_for_one_req = []
@@ -312,41 +281,19 @@ def recv_kv_caches_and_hidden_states(
            end_pos = start_pos + num_computed_tokens

            # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
-
-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+            for cur_layer in range(start_layer, end_layer):
+
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]
+
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

            hidden_or_intermediate_states_for_one_req.append(hidden)
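For context on what the connector now delegates to: below is a minimal, self-contained sketch of what `model_aware_kv_ops_helper` might look like, reconstructed purely from the branching logic removed in the hunks above and the call sites this diff introduces. The actual helper lives in `vllm/distributed/kv_transfer/kv_connector/utils.py` and may differ in structure and details.

```python
# Hypothetical reconstruction of model_aware_kv_ops_helper, assembled only
# from the logic removed in this diff; the real implementation in
# vllm/distributed/kv_transfer/kv_connector/utils.py may differ.
import vllm.envs as envs

from vllm import _custom_ops as ops
from vllm.config import VllmConfig


class model_aware_kv_ops_helper:

    def __init__(self, config: VllmConfig):
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable):
        model_config = model_executable.model.config
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads

        # DeepSeek MLA with forward absorb caches a single latent "head" of
        # size kv_lora_rank + qk_rope_head_dim; with VLLM_MLA_DISABLE=1 the
        # cache is per-head with qk_nope_head_dim + qk_rope_head_dim entries.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + \
                model_config.qk_rope_head_dim
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = model_config.qk_nope_head_dim + \
                model_config.qk_rope_head_dim
        else:
            head_size = getattr(model_config, "head_dim",
                                int(hidden_size // num_attention_heads))
        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        # MLA stores one fused latent tensor, so the "key" and "value" views
        # alias the same buffer; standard attention keeps K and V separately
        # in kv_cache[0] and kv_cache[1].
        if self.is_deepseek_mla and self.use_mla_opt:
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = kv_cache.reshape(-1, num_heads, head_size)
        else:
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(self, model_executable, keys, values, layer,
                        kv_cache, slot_mapping, start_pos, end_pos):
        model_config = model_executable.model.config

        if self.is_deepseek_mla and self.use_mla_opt:
            layer.self_attn.attn = layer.self_attn.mla_attn
            # keys arrives already indexed per layer (remote_k above).
            k_c_normed_k_pe = keys.to(kv_cache.device).squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe,
                kv_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
            )
        else:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
                layer.self_attn.attn._v_scale,
            )
```

With the MLA-vs-standard-attention branching centralized in one place, `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states` stay model-agnostic and only call `get_model_args`, `get_kv_from_cache`, and `put_kv_to_cache`.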