# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+""" An implementation of https://arxiv.org/pdf/2410.05258 """
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate

from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
-from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType,
                                               is_quantized_kv_cache)
from vllm.attention.backends.flash_attn import FlashAttentionBackend
# yapf: enable
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
+                                           compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_all_cross_attn_metadata_set,
                                            is_all_encoder_attn_metadata_set,

logger = init_logger(__name__)


-class DifferentialFlashAttentionBackend(FlashAttentionBackend):
+class DifferentialFlashAttentionBackend(AttentionBackend):
    accept_output_buffer = False

+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -69,6 +76,33 @@ def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
    def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
        return DifferentialFlashAttentionMetadataBuilder

+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
+

@dataclass
class DifferentialFlashAttentionMetadata(AttentionMetadata):
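For intuition, the swap_blocks and copy_blocks helpers added above delegate to vLLM's custom ops, which move whole cache blocks between paired key/value tensors: index 0 of each kv_cache holds the key plane and index 1 the value plane, while src_to_dst maps source block indices to destination block indices. The sketch below mimics that block copy in plain PyTorch; the shapes and the Python loop are illustrative assumptions, not the backend's actual KV-cache layout or kernels.

# Illustrative stand-in for ops.swap_blocks / ops.copy_blocks: copy whole
# blocks between paired key/value caches. Shapes are assumptions, not the
# layout returned by get_kv_cache_shape.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
src_kv_cache = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)
dst_kv_cache = torch.zeros_like(src_kv_cache)

# Each row maps a source block index to a destination block index.
src_to_dst = torch.tensor([[0, 3], [2, 5]])

for src_block, dst_block in src_to_dst.tolist():
    dst_kv_cache[0, dst_block] = src_kv_cache[0, src_block]  # key plane
    dst_kv_cache[1, dst_block] = src_kv_cache[1, src_block]  # value plane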
@@ -635,6 +669,8 @@ def __init__(
        use_irope: bool = False,
        differential_flash_attention_config: Optional[Dict[str, Any]] = None,
    ) -> None:
+        if differential_flash_attention_config is None:
+            differential_flash_attention_config = {}
        self.differential_flash_attention_config = \
            differential_flash_attention_config
        self.used_shared_kv_cache = \
@@ -722,7 +758,7 @@ def forward_generate_kv_cache(
            self, query: torch.Tensor, key: Optional[torch.Tensor],
            value: Optional[torch.Tensor], k_cache: torch.Tensor,
            v_cache: torch.Tensor,
-            attn_metadata: AttentionMetadata) -> torch.Tensor:
+            attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:

        head_size = self.head_size
        num_heads = self.num_heads // 2
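The halving of num_heads above reflects the differential attention scheme from the paper cited in the module docstring (arXiv:2410.05258): query and key heads are split into two groups, and the second softmax attention map is subtracted from the first, scaled by a learned lambda, to cancel common-mode attention noise. A simplified sketch of that subtraction, omitting the paper's shared value projection details, learned lambda re-parameterization, and per-head normalization:

# Simplified sketch of differential attention (arXiv:2410.05258), not the
# kernel path used by this backend. Tensors: [batch, num_heads, seq_len,
# head_dim]; lam stands in for the paper's learned lambda.
import torch

def differential_attention(q1, k1, q2, k2, v, lam: float = 0.5):
    scale = q1.shape[-1] ** -0.5
    attn1 = torch.softmax((q1 @ k1.transpose(-2, -1)) * scale, dim=-1)
    attn2 = torch.softmax((q2 @ k2.transpose(-2, -1)) * scale, dim=-1)
    # Subtracting the second map cancels attention noise common to both.
    return (attn1 - lam * attn2) @ v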
@@ -758,7 +794,9 @@ def forward_generate_kv_cache(

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
-            if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
+            if k_cache.numel() == 0 \
+                or prefill_meta.block_tables is None \
+                or prefill_meta.block_tables.numel() == 0:
                # normal attention
                prefill_output = flash_attn_varlen_func(
                    q=query,
@@ -808,7 +846,7 @@ def forward_with_kv_cache_only(
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        attn_metadata: DifferentialFlashAttentionMetadata,
    ):
        if not attn_metadata.decode_metadata:
            block_tables_arg = attn_metadata.cross_layer_shared_block_tables
@@ -838,6 +876,7 @@ def forward(
        kv_cache: torch.Tensor,
        attn_metadata: DifferentialFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.