
Commit ab802f3

minor
1 parent 349e17a commit ab802f3

File tree

2 files changed: +6 -5 lines changed


vllm/attention/backends/differential_flash_attn.py

Lines changed: 3 additions & 2 deletions

@@ -54,6 +54,7 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
+        assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"
         return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)

     @staticmethod
@@ -872,7 +873,7 @@ def forward(
         k1, k2 = self.split_heads(k)
         v1, v2 = self.split_heads(v)

-        # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size)
+        # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
         # Split by half along the first dimension.
         kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
         assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
@@ -909,7 +910,7 @@ def forward(
         else: # re-use the kv cache, full attention
             q = q.view(-1, self.num_heads, self.head_size)
             q1, q2 = self.split_heads(q)
-            # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size)
+            # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size)
            kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
            key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
            key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
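For context, a minimal sketch of why the corrected comment matches the shape returned by get_kv_cache_shape. The sizes below are made up for illustration, and torch.unbind stands in for the backend's split_kv_cache helper, whose implementation is not shown in this diff:

import torch

# Hypothetical sizes, chosen only for illustration.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"

# Two differential KV groups (leading 2), each holding a key/value pair
# (second 2); each group owns half of the KV heads.
kv_cache = torch.zeros(2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)

# "Split by half along the first dimension" yields one cache per group.
kv_cache1, kv_cache2 = kv_cache.unbind(0)
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
print(key_cache1.shape)  # torch.Size([8, 16, 2, 64])

The new assertion guards the num_kv_heads // 2 in this shape: with an odd head count the two groups could not split the KV heads evenly.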

vllm/model_executor/models/phi4flash.py

Lines changed: 3 additions & 3 deletions

@@ -517,9 +517,9 @@ def forward(
         if kv_cache[0].numel() == 0:
             break

-        # Starting from this layer, we do not need to cuculate the kv cache since we reuse
-        # the kv cache from last layer. If in prefill phase, we can prune truncate
-        # hidden state to save computation cost.
+        # Starting from this layer, we do not need to calculate the kv cache since we reuse
+        # the kv cache from last layer. If in prefill phase, we can truncate
+        # the hidden state to save computation cost.
         if attn_metadata.prefill_metadata:
             selected_token_indices = torch.cumsum(attn_metadata.seq_lens_tensor, dim=0) - 1
             hidden_states = hidden_states.index_select(0, selected_token_indices)
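As a rough illustration of the truncation this comment describes (toy sizes, not taken from the model): in the prefill phase only the last token of each sequence needs to flow through the remaining layers, and cumsum - 1 over the sequence lengths picks out exactly those positions in the flattened batch:

import torch

# Hypothetical prefill batch: three sequences of lengths 3, 5, and 2,
# flattened into a single (10, hidden_size) tensor of token states.
seq_lens_tensor = torch.tensor([3, 5, 2])
hidden_states = torch.arange(10, dtype=torch.float32).unsqueeze(-1)

# cumsum - 1 gives the index of each sequence's last token: [2, 7, 9].
selected_token_indices = torch.cumsum(seq_lens_tensor, dim=0) - 1
hidden_states = hidden_states.index_select(0, selected_token_indices)
print(hidden_states.squeeze(-1))  # tensor([2., 7., 9.])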
