Commit 0e20e17

address lint
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent: bb83dcb · commit: 0e20e17

File tree

1 file changed (+45, -6 lines)

vllm/attention/backends/differential_flash_attn.py

Lines changed: 45 additions & 6 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""" An implementation of https://arxiv.org/pdf/2410.05258 """
 from collections import defaultdict
 from dataclasses import dataclass
 from itertools import accumulate
@@ -11,14 +12,16 @@
 from vllm import _custom_ops as ops
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.backends.flash_attn import FlashAttentionBackend
 # yapf: enable
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
+                                           compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_all_cross_attn_metadata_set,
                                            is_all_encoder_attn_metadata_set,
@@ -38,9 +41,13 @@
 logger = init_logger(__name__)
 
 
-class DifferentialFlashAttentionBackend(FlashAttentionBackend):
+class DifferentialFlashAttentionBackend(AttentionBackend):
     accept_output_buffer = False
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -69,6 +76,33 @@ def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
     def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
         return DifferentialFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
+
 
 @dataclass
 class DifferentialFlashAttentionMetadata(AttentionMetadata):
@@ -635,6 +669,8 @@ def __init__(
         use_irope: bool = False,
         differential_flash_attention_config: Optional[Dict[str, Any]] = None,
     ) -> None:
+        if differential_flash_attention_config is None:
+            differential_flash_attention_config = {}
         self.differential_flash_attention_config = \
             differential_flash_attention_config
         self.used_shared_kv_cache = \
@@ -722,7 +758,7 @@ def forward_generate_kv_cache(
             self, query: torch.Tensor, key: Optional[torch.Tensor],
             value: Optional[torch.Tensor], k_cache: torch.Tensor,
             v_cache: torch.Tensor,
-            attn_metadata: AttentionMetadata) -> torch.Tensor:
+            attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
 
         head_size = self.head_size
         num_heads = self.num_heads // 2
@@ -758,7 +794,9 @@ def forward_generate_kv_cache(
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
+            if k_cache.numel() == 0 \
+                or prefill_meta.block_tables is None \
+                or prefill_meta.block_tables.numel() == 0:
                 # normal attention
                 prefill_output = flash_attn_varlen_func(
                     q=query,
@@ -808,7 +846,7 @@ def forward_with_kv_cache_only(
         query: torch.Tensor,
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        attn_metadata: DifferentialFlashAttentionMetadata,
     ):
         if not attn_metadata.decode_metadata:
             block_tables_arg = attn_metadata.cross_layer_shared_block_tables
@@ -838,6 +876,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: DifferentialFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
 
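Note on the new swap_blocks / copy_blocks static methods added above: they treat each layer's KV cache as a paired tensor whose index 0 is the key cache and index 1 is the value cache, and delegate the actual block movement to vllm._custom_ops. The sketch below is a minimal pure-PyTorch illustration of that calling convention and block layout, not vLLM's kernels; the helper names, tensor shapes, and block mappings are assumptions made only for this example.

# Minimal pure-PyTorch sketch of the block-movement convention used by the new
# static methods. NOT vLLM's CUDA kernels -- helper names, shapes, and the
# block mappings below are assumptions for illustration only.
from typing import List

import torch


def copy_blocks_reference(kv_caches: List[torch.Tensor],
                          src_to_dists: torch.Tensor) -> None:
    # Each kv_cache is shaped (2, num_blocks, ...): index 0 is the key cache,
    # index 1 the value cache. Each row of src_to_dists is (src_block, dst_block).
    for kv_cache in kv_caches:
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        for src, dst in src_to_dists.tolist():
            key_cache[dst].copy_(key_cache[src])
            value_cache[dst].copy_(value_cache[src])


def swap_blocks_reference(src_kv_cache: torch.Tensor,
                          dst_kv_cache: torch.Tensor,
                          src_to_dst: torch.Tensor) -> None:
    # Moves blocks from one cache to another, key and value planes together,
    # mirroring the two ops.swap_blocks calls in the diff.
    for src, dst in src_to_dst.tolist():
        dst_kv_cache[0][dst].copy_(src_kv_cache[0][src])
        dst_kv_cache[1][dst].copy_(src_kv_cache[1][src])


if __name__ == "__main__":
    # Toy sizes: 8 blocks of 16 tokens, 4 KV heads, head size 128 (assumed).
    shape = (2, 8, 16, 4, 128)
    src = torch.randn(shape)
    dst = torch.zeros(shape)
    copy_blocks_reference([src], torch.tensor([[0, 1], [2, 3]]))
    swap_blocks_reference(src, dst, torch.tensor([[0, 0]]))
    assert torch.equal(dst[:, 0], src[:, 0])

In the backend itself these loops are handled by the fused ops.copy_blocks / ops.swap_blocks custom ops that the diff delegates to.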