
Commit c4f1be1

billishyahao authored and yangw-dev committed
[Feature] add model aware kv ops helper (vllm-project#16020)
Signed-off-by: billishyahao <bill.he@amd.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
1 parent 257b00e commit c4f1be1

File tree

3 files changed (+123, -101 lines)


vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py

Lines changed: 12 additions & 27 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 MooncakeStore Connector for Distributed Machine Learning Inference
-
 The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 (KV cache producer) and decode vLLM workers (KV cache consumer) using a
 database-style KVStore.
@@ -11,9 +10,10 @@

 import torch

-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors

@@ -32,8 +32,7 @@ def __init__(
         config: VllmConfig,
     ):
         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-
+        self.kv_helper = kv_helper(config)
         self.local_tp_rank = local_rank

         # Init kv_store
@@ -80,12 +79,7 @@ def send_kv_caches_and_hidden_states(
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

         for idx, slen in enumerate(seq_lens):
             start_pos = sum(seq_lens[:idx])
@@ -97,10 +91,8 @@ def send_kv_caches_and_hidden_states(

             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
-
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                 keys.append(key_cache[current_slot_mapping].unsqueeze(0))
@@ -173,22 +165,15 @@ def recv_kv_caches_and_hidden_states(
                 layer = model_executable.model.layers[layer_id]
                 # get kvcache object
                 kv_cache = kv_caches[layer_id - start_layer]
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                # get remote kvcache

+                # get remote kvcache
                 remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                     layer_id]
-                # use ops.reshape_and_cache_flash to put kv into kvcache
-                ops.reshape_and_cache_flash(
-                    remote_k.to(key_cache.device),
-                    remote_v.to(value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

             hidden_or_intermediate_states_for_one_req.append(hidden)

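For reference, a minimal sketch (not part of this commit) of what the refactored send path boils down to once the helper is in play. The function name gather_kv_for_send is hypothetical; config is assumed to be the engine's VllmConfig, and model_executable, kv_caches, slot_mapping_flat, start_pos, and end_pos are assumed to be supplied by the surrounding connector code, as in the diff above.

# Illustrative sketch only: how a connector's send path can use the helper.
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


def gather_kv_for_send(config, model_executable, kv_caches,
                       slot_mapping_flat, start_pos, end_pos):
    helper = kv_helper(config)
    # One call replaces the per-connector num_heads/head_size bookkeeping.
    num_heads, head_size = helper.get_model_args(model_executable)

    start_layer = model_executable.model.start_layer
    end_layer = model_executable.model.end_layer

    keys, values = [], []
    for layer_id in range(start_layer, end_layer):
        kv_cache = kv_caches[layer_id - start_layer]
        # The helper hides the MLA vs. standard kv_cache layout difference.
        key_cache, value_cache = helper.get_kv_from_cache(
            kv_cache, num_heads, head_size)
        current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
        keys.append(key_cache[current_slot_mapping].unsqueeze(0))
        values.append(value_cache[current_slot_mapping].unsqueeze(0))
    return keys, values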

vllm/distributed/kv_transfer/kv_connector/simple_connector.py

Lines changed: 21 additions & 74 deletions
@@ -12,10 +12,10 @@

 import torch

-import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    model_aware_kv_ops_helper as kv_helper)
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
@@ -37,9 +37,7 @@ def __init__(
     ):

         self.config = config.kv_transfer_config
-        self.tp_size = config.parallel_config.tensor_parallel_size
-        self.is_deepseek_mla = config.model_config.is_deepseek_mla
-        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.kv_helper = kv_helper(config)

         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states(
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
-
-        model_config = model_executable.model.config
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-
-        # Deepseek's MLA (Multi-head Latent Attention) uses two different
-        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
-        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
-        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
-        # kv_lora_rank + qk_rope_head_dim].
-        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
-        # to a kv_cache shape of [2, num_blks, blk_size,
-        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
-        # For more details, see vllm/attention/backends/mla/common.py.
-        if self.is_deepseek_mla and self.use_mla_opt:
-            head_size = model_config.kv_lora_rank + \
-                model_config.qk_rope_head_dim
-            num_heads = 1
-        elif self.is_deepseek_mla and not self.use_mla_opt:
-            head_size = model_config.qk_nope_head_dim + \
-                model_config.qk_rope_head_dim
-        else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -212,13 +186,8 @@ def send_kv_caches_and_hidden_states(

             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
-                else:
-                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
+                    kv_cache, num_heads, head_size)

                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

@@ -248,12 +217,12 @@ def recv_kv_caches_and_hidden_states(
         # and hidden states.
         bypass_model_exec = True

-        model_config = model_executable.model.config
-
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer

         hidden_or_intermediate_states_for_one_req = []

@@ -312,41 +281,19 @@ def recv_kv_caches_and_hidden_states(
             end_pos = start_pos + num_computed_tokens

             # put received KV caches into paged memory
-            for i in range(model_executable.model.start_layer,
-                           model_executable.model.end_layer):
-
-                kv_cache = kv_caches[i - model_executable.model.start_layer]
-                layer = model_executable.model.layers[i]
-
-                if self.is_deepseek_mla and self.use_mla_opt:
-                    layer.self_attn.attn = layer.self_attn.mla_attn
-                    k_c_normed_k_pe = keys[
-                        i - model_executable.model.start_layer].to(
-                            kv_cache.device).squeeze(1)
-                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
-                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
-                    ops.concat_and_cache_mla(
-                        k_c_normed,
-                        k_pe,
-                        kv_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                    )
-                else:
-                    key_cache, value_cache = kv_cache[0], kv_cache[1]
-                    ops.reshape_and_cache_flash(
-                        keys[i - model_executable.model.start_layer].to(
-                            key_cache.device),
-                        values[i - model_executable.model.start_layer].to(
-                            value_cache.device),
-                        key_cache,
-                        value_cache,
-                        slot_mapping[start_pos:end_pos],
-                        layer.self_attn.attn.kv_cache_dtype,
-                        layer.self_attn.attn._k_scale,
-                        layer.self_attn.attn._v_scale,
-                    )
+            for cur_layer in range(start_layer, end_layer):
+
+                layer_id = cur_layer - start_layer
+                kv_cache = kv_caches[layer_id]
+                layer = model_executable.model.layers[cur_layer]
+
+                # get remote kvcache
+                remote_k, remote_v = keys[layer_id], values[layer_id]
+
+                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
+                                               remote_v, layer, kv_cache,
+                                               slot_mapping, start_pos,
+                                               end_pos)

             hidden_or_intermediate_states_for_one_req.append(hidden)

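The head-geometry branching that both connectors previously duplicated now lives in get_model_args. A rough worked example of the arithmetic it performs, using illustrative numbers that are not taken from the commit:

# Worked example (illustrative values only) of the num_heads/head_size
# selection that get_model_args centralizes.
from types import SimpleNamespace

# Hypothetical MLA-style config (values chosen for illustration only).
mla_cfg = SimpleNamespace(kv_lora_rank=512, qk_rope_head_dim=64,
                          qk_nope_head_dim=128)
# MLA-optimized path (VLLM_MLA_DISABLE=0): a single latent "head" per token,
# sized kv_lora_rank + qk_rope_head_dim.
num_heads, head_size = 1, mla_cfg.kv_lora_rank + mla_cfg.qk_rope_head_dim
assert (num_heads, head_size) == (1, 576)

# Hypothetical standard GQA config with tensor parallelism of 2.
gqa_cfg = SimpleNamespace(num_key_value_heads=8, hidden_size=4096,
                          num_attention_heads=32, tp_size=2)
num_heads = gqa_cfg.num_key_value_heads // gqa_cfg.tp_size
head_size = getattr(gqa_cfg, "head_dim",
                    gqa_cfg.hidden_size // gqa_cfg.num_attention_heads)
assert (num_heads, head_size) == (4, 128)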

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+KV cache helper for store.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class model_aware_kv_ops_helper:
+
+    def __init__(self, config: VllmConfig):
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
+        self.tp_size = config.parallel_config.tensor_parallel_size
+
+    def get_model_args(self, model_executable: torch.nn.Module):
+
+        model_config = model_executable.model.config
+        self.model_executable = model_executable
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
+
+        return num_heads, head_size
+
+    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
+        if self.is_deepseek_mla and self.use_mla_opt:
+            key_cache = kv_cache.reshape(-1, num_heads, head_size)
+            value_cache = kv_cache.reshape(-1, num_heads, head_size)
+        else:
+            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+        return key_cache, value_cache
+
+    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
+                        layer, kv_cache, slot_mapping, start_pos, end_pos):
+
+        model_config = model_executable.model.config
+
+        if self.is_deepseek_mla and self.use_mla_opt:
+            layer.self_attn.attn = layer.self_attn.mla_attn
+            k_c_normed_k_pe = keys.squeeze(1)
+            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+            ops.concat_and_cache_mla(
+                k_c_normed.to(kv_cache.device),
+                k_pe.to(kv_cache.device),
+                kv_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+            )
+        else:
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            ops.reshape_and_cache_flash(
+                keys.to(key_cache.device),
+                values.to(value_cache.device),
+                key_cache,
+                value_cache,
+                slot_mapping[start_pos:end_pos],
+                layer.self_attn.attn.kv_cache_dtype,
+                layer.self_attn.attn._k_scale,
+                layer.self_attn.attn._v_scale,
+            )
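A complementary sketch (again not part of the commit) of the receive path, mirroring how both connectors above now write remote KV back into paged memory through put_kv_to_cache. The function name scatter_received_kv is hypothetical; config is assumed to be the engine's VllmConfig, and the remaining arguments are assumed to come from the surrounding connector code.

# Illustrative sketch only: a receive path built on the helper.
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)


def scatter_received_kv(config, model_executable, kv_caches, keys, values,
                        slot_mapping, start_pos, end_pos):
    helper = kv_helper(config)
    start_layer = model_executable.model.start_layer
    end_layer = model_executable.model.end_layer

    for cur_layer in range(start_layer, end_layer):
        layer_id = cur_layer - start_layer
        kv_cache = kv_caches[layer_id]
        layer = model_executable.model.layers[cur_layer]
        remote_k, remote_v = keys[layer_id], values[layer_id]
        # The helper chooses concat_and_cache_mla or reshape_and_cache_flash
        # based on whether the model is DeepSeek-MLA and whether the MLA
        # optimization (VLLM_MLA_DISABLE unset) is in effect.
        helper.put_kv_to_cache(model_executable, remote_k, remote_v, layer,
                               kv_cache, slot_mapping, start_pos, end_pos)

Because both connectors call the same three helper methods, the MLA-specific branches that were previously copy-pasted into each connector could be deleted in this commit.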
