Commit 11e30fe

address comment
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent: 875e85b

13 files changed: 60 additions & 17 deletions


tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -246,6 +246,8 @@ def check_available_online(
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                             trust_remote_code=True,
                                             v0_only=True),
+    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning",  # noqa: E501
+                                            trust_remote_code=True),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",

tests/models/test_initialization.py

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
                        _initialize_kv_caches_v1), monkeypatch.context() as m):
         if model_info.v0_only:
             m.setenv("VLLM_USE_V1", "0")
+        if model_arch == "Phi4FlashForCausalLM":
+            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
+            m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,
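
For context, a minimal sketch (not part of this commit) of how the same backend pinning looks when running the model directly with vLLM; the model name, architecture, and environment variable come from the diffs above, everything else is illustrative:

    import os

    from vllm import LLM, SamplingParams

    # Phi4FlashForCausalLM only runs on the DIFFERENTIAL_FLASH_ATTN backend,
    # so pin the backend before the engine is constructed.
    os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

    llm = LLM(model="microsoft/Phi-4-mini-flash-reasoning",
              trust_remote_code=True)
    outputs = llm.generate(["Explain KV-cache sharing in one sentence."],
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)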

tests/test_utils.py

Lines changed: 25 additions & 0 deletions
@@ -458,6 +458,31 @@ def test_bind_kv_cache():
     assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
     assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
 
+def test_bind_kv_cache_kv_sharing():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    shared_kv_cache_layers = {
+        'layers.2.self_attn': 'layers.1.self_attn',
+        'layers.3.self_attn': 'layers.0.self_attn'
+    }
+    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
+
 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
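
One reading of the new test (inferred from the asserts above, not an authoritative description of bind_kv_cache): each key in shared_kv_cache_layers names a layer that reuses another layer's KV cache, and the value names the layer that owns it. A minimal sketch of that lookup, with a hypothetical helper name:

    # Hypothetical helper, not part of vLLM: picks the tensor a layer binds to.
    def resolve_kv_cache(layer_name, kv_cache_by_layer, shared_kv_cache_layers):
        # A layer listed as a key borrows the cache of the layer it maps to;
        # any other layer keeps its own cache.
        owner = shared_kv_cache_layers.get(layer_name, layer_name)
        return kv_cache_by_layer[owner]

    # With the mapping from the test, 'layers.2.self_attn' resolves to the
    # cache owned by 'layers.1.self_attn', matching the `is kv_cache[1]` assert.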

vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 1 deletion
@@ -308,7 +308,8 @@ def __init__(
         kv_sharing_target_layer_name: Optional[str] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "BLOCK_SPARSE_FLASH_ATTN Backend.")
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for blocksparse flash attention.")

vllm/attention/backends/differential_flash_attn.py

Lines changed: 0 additions & 2 deletions
@@ -676,8 +676,6 @@ def __init__(
         self.used_shared_kv_cache = \
             self.differential_flash_attention_config.get(
                 "used_shared_kv_cache", False)
-        # if kv_sharing_target_layer_name is not None:
-        #     raise NotImplementedError("KV sharing is not supported in V0.")
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
         if blocksparse_params is not None:
             raise ValueError(
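
For contrast with the guards added in the backends below, the differential backend now keeps the target layer name instead of rejecting it. A rough sketch of the two patterns (simplified constructor, illustrative class name, not a real vLLM class):

    from typing import Optional

    class ExampleAttentionImpl:  # illustrative only
        def __init__(self,
                     kv_sharing_target_layer_name: Optional[str] = None,
                     supports_kv_sharing: bool = False) -> None:
            if kv_sharing_target_layer_name is not None and not supports_kv_sharing:
                # The V0 backends below keep this guard, now with a
                # backend-specific message.
                raise NotImplementedError(
                    "KV sharing is not supported in this backend.")
            # Backends that support it record which layer's cache to read.
            self.kv_sharing_target_layer_name = kv_sharing_target_layer_name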

vllm/attention/backends/dual_chunk_flash_attn.py

Lines changed: 2 additions & 1 deletion
@@ -295,7 +295,8 @@ def __init__(
         dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "DUAL_CHUNK_FLASH_ATTN backend.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

vllm/attention/backends/flash_attn.py

Lines changed: 2 additions & 1 deletion
@@ -622,7 +622,8 @@ def __init__(
         use_irope: bool = False,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "FLASH_ATTN backend.")
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 1 deletion
@@ -938,7 +938,8 @@ def __init__(
         use_irope: bool = False,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "FLASHINFER backend.")
         if use_irope:
             logger.warning_once(
                 "Using irope in FlashInfer is not supported yet, it will fall"

vllm/attention/backends/hpu_attn.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def __init__(
     ) -> None:
         super(AttentionImpl, self).__init__()
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "HPU_ATTN backend.")
         if use_irope:
             logger.warning_once(
                 "Using irope in HPU is not supported yet, it will fall back "

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 1 deletion
@@ -499,7 +499,8 @@ def __init__(
         use_irope: bool = False,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "ROCM_FLASH backend.")
         if use_irope:
             logger.warning_once(
                 "Using irope in ROCm Flash Attention is not supported yet, it "
