
Commit c4d3f2f

address comments
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent c5f769b commit c4d3f2f

File tree: 6 files changed (+53, -52 lines)


vllm/attention/backends/abstract.py

Lines changed: 0 additions & 2 deletions
@@ -30,8 +30,6 @@ class AttentionType:
     ENCODER_ONLY = "encoder_only"
     # Attention between dec. Q and enc. K/V for encoder-decoder
     ENCODER_DECODER = "encoder_decoder"
-    # Attention layer that reuse kv cache
-    DECODER_DECODER = "decoder_decoder"


 class AttentionBackend(ABC):

vllm/attention/backends/differential_flash_attn.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention layer with FlashAttention."""
 from collections import defaultdict
 from dataclasses import dataclass
 from itertools import accumulate
@@ -55,7 +54,6 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-        # return (2, num_blocks, block_size, num_kv_heads, head_size)
         return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)

     @staticmethod
@@ -634,8 +632,9 @@ def __init__(
         self.differential_flash_attention_config = differential_flash_attention_config
         self.used_shared_kv_cache = self.differential_flash_attention_config.get(
             "used_shared_kv_cache", False)
-        if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+        # if kv_sharing_target_layer_name is not None:
+        #     raise NotImplementedError("KV sharing is not supported in V0.")
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")

vllm/attention/layer.py

Lines changed: 3 additions & 3 deletions
@@ -160,9 +160,9 @@ def __init__(
         self.attn_type = attn_type

         if kv_sharing_target_layer_name is not None:
-            if not envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "Cross-layer KV sharing is not supported in V0.")
+            # if not envs.VLLM_USE_V1:
+            #     raise NotImplementedError(
+            #         "Cross-layer KV sharing is not supported in V0.")

             validate_kv_sharing_target(
                 prefix,

vllm/model_executor/models/phi4flash.py

Lines changed: 15 additions & 40 deletions
@@ -138,7 +138,13 @@ def __init__(self,
             "subln": self.subln,
             }
         }
-
+
+        if yoco_cross:
+            kv_shared_layer_index = config.num_hidden_layers//2 + 1
+            kv_sharing_target_layer_name = f"model.layers.{kv_shared_layer_index}.self_attn.attn"  # noqa: E501
+        else:
+            kv_sharing_target_layer_name = None
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -147,7 +153,8 @@ def __init__(self,
             cache_config=cache_config,
             per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn",
-            attn_type=AttentionType.DECODER_DECODER if self.yoco_cross else AttentionType.DECODER,
+            attn_type=AttentionType.DECODER,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
             **params
         )

@@ -157,9 +164,6 @@ def lambda_init_fn(self, depth):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        positions: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ):

         if not self.yoco_cross:  # need to generate kv-cache
@@ -168,9 +172,6 @@ def forward(
             attn_output = self.attn(q, k, v)
         else:  # re-use the kv cache, full attention
             q = self.Wqkv(hidden_states)
-            virtual_engine = get_virtual_engine()
-            if self.attn.kv_cache[virtual_engine].numel() == 0:
-                self.attn.kv_cache[virtual_engine] = kv_cache
             attn_output = self.attn(q, None, None)
         attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
         return self.out_proj(attn_output)
@@ -417,15 +418,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         ssm_output: Optional[torch.LongTensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if self.use_mamba:
-            assert kv_cache is None and mamba_cache_params is not None
+            assert mamba_cache_params is not None
         else:
-            assert kv_cache is not None and mamba_cache_params is None
+            assert mamba_cache_params is None

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype))
@@ -441,9 +441,6 @@ def forward(
         else:
             attn_outputs = self.attn(
                 hidden_states,
-                positions,
-                kv_cache,
-                attn_metadata,
             )
             hidden_states = residual + attn_outputs
         residual = hidden_states
@@ -452,12 +449,7 @@ def forward(
         hidden_states = residual + hidden_states

         return hidden_states, ssm_output
-
-def get_kv_cache(layer_name):
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return kv_cache
+


 class SambaYModel(nn.Module):

@@ -513,16 +505,16 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]

-        kv_cache_idx = 0
         mamba_state_idx = 0
         ssm_output = None
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             if i == self.config.num_hidden_layers // 2 + 2:
                 # profile run
+                kv_cache_idx = self.config.num_hidden_layers//2 + 1
                 cache_layer = self.layers[kv_cache_idx]
-                kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name)
-                if kv_cache.numel() == 0:
+                kv_cache = cache_layer.attn.attn.kv_cache
+                if kv_cache[0].numel() == 0:
                     break

                 # Starting from this layer, we do not need to cuculate the kv cache since we reuse
@@ -546,31 +538,14 @@ def forward(
                 hidden_states, ssm_output = layer(
                     hidden_states,
                     positions,
-                    None,  # kv_cache
                     attn_metadata,
                     mamba_cache,
                     ssm_output = ssm_output
                 )
             else:
-                if i < self.config.num_hidden_layers // 2:
-                    # sliding window attention
-                    cache_layer = self.layers[i]
-                    kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name)
-                    kv_cache_idx = i
-                elif not layer.yoco_cross:
-                    # full attention that generates kv cache
-                    cache_layer = self.layers[i]
-                    kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name)
-                    kv_cache_idx = i
-                else:
-                    # full attention that reuses kv cache
-                    cache_layer = self.layers[kv_cache_idx]
-                    kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name)
-
                 hidden_states, ssm_output = layer(
                     hidden_states,
                     positions,
-                    kv_cache,
                     attn_metadata,
                     None,  # mamba_cache_params
                     ssm_output = ssm_output
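
For illustration, a minimal sketch (not vLLM code) of how the kv_sharing_target_layer_name added above resolves; the 32-layer count and printed values are assumptions for this example, not taken from the commit.

from typing import Optional

num_hidden_layers = 32  # hypothetical value, not read from this commit

def kv_sharing_target(yoco_cross: bool) -> Optional[str]:
    # Mirrors the yoco_cross branch added in __init__ above.
    if not yoco_cross:
        return None  # this layer writes its own KV cache
    # Every cross layer reads the cache written by the single producer layer.
    kv_shared_layer_index = num_hidden_layers // 2 + 1
    return f"model.layers.{kv_shared_layer_index}.self_attn.attn"

print(kv_sharing_target(True))   # model.layers.17.self_attn.attn
print(kv_sharing_target(False))  # None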

vllm/utils/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -2881,6 +2881,7 @@ def get_mp_context():
 def bind_kv_cache(
     ctx: dict[str, Any],
     kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
+    shared_kv_cache_layers: dict[str, str],
 ) -> None:
     # Bind the kv_cache tensor to Attention modules, similar to
     # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
@@ -2904,11 +2905,16 @@ def bind_kv_cache(
             extract_layer_index(layer_name)
             for layer_name in layer_need_kv_cache))
     for layer_name in layer_need_kv_cache:
+        target_layer_name = shared_kv_cache_layers[layer_name] if layer_name \
+            in shared_kv_cache_layers else layer_name
         kv_cache_idx = layer_index_sorted.index(
-            extract_layer_index(layer_name))
+            extract_layer_index(target_layer_name))
         forward_ctx = ctx[layer_name]
         assert len(forward_ctx.kv_cache) == len(kv_cache)
+
         for ve, ve_kv_cache in enumerate(kv_cache):
+            assert kv_cache_idx < len(ve_kv_cache), \
+                "v0 doesn't support interleaving kv sharing, use v1 instead"
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

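
As a self-contained sketch of the redirection that the new shared_kv_cache_layers argument performs, consider the toy setup below; the layer names, tensor shapes, and the simplified local extract_layer_index helper are made up for this example and are not vLLM code.

import torch

# Toy setup: three attention layers, one virtual engine, and only two cache
# tensors, because layer 2 reuses the cache written by layer 1.
kv_cache = [[torch.zeros(2, 16, 4, 8), torch.zeros(2, 16, 4, 8)]]
layer_names = ["model.layers.0.attn", "model.layers.1.attn", "model.layers.2.attn"]
shared_kv_cache_layers = {"model.layers.2.attn": "model.layers.1.attn"}

def extract_layer_index(name: str) -> int:  # simplified stand-in
    return int(name.split(".")[2])

layer_index_sorted = sorted(extract_layer_index(n) for n in layer_names)

bound = {}
for name in layer_names:
    # Redirect a sharing layer to its target before computing the cache slot,
    # as the patched bind_kv_cache does above.
    target = shared_kv_cache_layers.get(name, name)
    idx = layer_index_sorted.index(extract_layer_index(target))
    assert idx < len(kv_cache[0]), "interleaved sharing needs vLLM v1"
    bound[name] = kv_cache[0][idx]

# The sharing layer ends up bound to the exact same tensor as its target.
assert bound["model.layers.2.attn"] is bound["model.layers.1.attn"]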

vllm/worker/worker.py

Lines changed: 25 additions & 2 deletions
@@ -9,7 +9,8 @@
 import torch.distributed

 import vllm.envs as envs
-from vllm.config import VllmConfig
+from vllm.attention.layer import Attention
+from vllm.config import (VllmConfig, get_layers_from_vllm_config)
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
@@ -26,6 +27,7 @@
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
 from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
                         memory_profiling)
+from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -345,8 +347,29 @@ def _init_cache_engine(self):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+
+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        shared_kv_cache_layers: dict[str, str] = {}
+
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+
+        for layer_name, attn_module in attn_layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                shared_kv_cache_layers[layer_name] = kv_tgt_layer
+
         bind_kv_cache(self.compilation_config.static_forward_context,
-                      self.gpu_cache)
+                      self.gpu_cache, shared_kv_cache_layers)

     def _warm_up_model(self) -> None:
         # warm up sizes that are not in cudagraph capture sizes,
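
To make the resulting mapping concrete, here is a hedged illustration of what shared_kv_cache_layers could contain for a model in which every attention layer past the midpoint reuses the single producer layer's cache; the 8-layer count and module paths are assumptions for this example, not read from the commit.

num_hidden_layers = 8  # hypothetical
producer = f"model.layers.{num_hidden_layers // 2 + 1}.self_attn.attn"

shared_kv_cache_layers = {
    f"model.layers.{i}.self_attn.attn": producer
    for i in range(num_hidden_layers // 2 + 2, num_hidden_layers)
}
# -> {'model.layers.6.self_attn.attn': 'model.layers.5.self_attn.attn',
#     'model.layers.7.self_attn.attn': 'model.layers.5.self_attn.attn'}

bind_kv_cache then points each key's forward context at the cache tensor allocated for the producer layer, so the second-half attention layers consume no additional KV-cache memory.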
