Commit 88cb796

address comments
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent: bc52add

2 files changed: +3 / -4 lines


vllm/attention/backends/differential_flash_attn.py

Lines changed: 1 addition & 3 deletions
@@ -673,9 +673,7 @@ def __init__(
             differential_flash_attention_config = {}
         self.differential_flash_attention_config = \
             differential_flash_attention_config
-        self.used_shared_kv_cache = \
-            self.differential_flash_attention_config.get(
-                "used_shared_kv_cache", False)
+        self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
         if blocksparse_params is not None:
             raise ValueError(
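
The edit in this backend replaces a config-dict lookup with a direct check on the constructor argument: the attention layer now treats the KV cache as shared exactly when a kv_sharing_target_layer_name is supplied, so the flag can no longer drift out of sync with the layer name. Below is a minimal, self-contained sketch of that pattern; the class and its simplified arguments are illustrative stand-ins, not the actual vLLM backend.

from typing import Optional


class DifferentialAttentionSketch:
    """Illustrative stand-in for the backend's __init__ logic (not vLLM code)."""

    def __init__(
        self,
        differential_flash_attention_config: Optional[dict] = None,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        if differential_flash_attention_config is None:
            differential_flash_attention_config = {}
        self.differential_flash_attention_config = \
            differential_flash_attention_config
        # Old behavior: read "used_shared_kv_cache" out of the config dict,
        # defaulting to False. New behavior: derive the flag from whether a
        # KV-sharing target layer was actually named.
        self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name


# Usage sketch: the flag tracks the argument, with no separate config key.
plain = DifferentialAttentionSketch()
assert plain.used_shared_kv_cache is False
shared = DifferentialAttentionSketch(
    kv_sharing_target_layer_name="model.layers.1.self_attn.attn")
assert shared.used_shared_kv_cache is True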

vllm/model_executor/models/phi4flash.py

Lines changed: 2 additions & 1 deletion
@@ -147,7 +147,6 @@ def __init__(self,
 
         params = {
             'differential_flash_attention_config': {
-                'used_shared_kv_cache': self.yoco_cross,
                 'lambda_init': self.lambda_init,
                 'lambda_q1': self.lambda_q1,
                 'lambda_k1': self.lambda_k1,
@@ -661,6 +660,8 @@ def forward(
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
         attn_metadata = get_forward_context().attn_metadata
+        # input_ids and hidden_states isn't a one-to-one mapping in prefill
+        # stage due to YOCO optimization.
         hidden_states = self.model(input_ids, positions, attn_metadata,
                                    mamba_cache_params, intermediate_tensors,
                                    inputs_embeds)
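
The new comment points at a YOCO-specific wrinkle worth spelling out: during prefill, the cross-decoder half of a YOCO-style model reuses the shared KV cache and can skip most prompt positions, so the hidden_states returned by self.model(...) may cover fewer tokens than input_ids. The toy function below only illustrates that shape mismatch; its names and shapes are made up for this sketch and do not come from the vLLM implementation.

import torch


def toy_yoco_prefill(input_ids: torch.Tensor, hidden_size: int = 8) -> torch.Tensor:
    """Toy stand-in for a prefill forward pass with YOCO-style KV sharing.

    The "self-decoder" half processes every prompt token once to fill the
    shared KV cache, but the "cross-decoder" half only needs hidden states
    for the last position to start decoding, so fewer rows come back.
    """
    num_tokens = input_ids.shape[0]
    full_hidden = torch.randn(num_tokens, hidden_size)  # self-decoder output
    return full_hidden[-1:]  # cross-decoder output: final position only


prompt_ids = torch.arange(5)           # 5 prompt tokens
hidden = toy_yoco_prefill(prompt_ids)  # shape (1, 8), not (5, 8)
assert hidden.shape[0] != prompt_ids.shape[0]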
