|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import math |
| 5 | + |
| 6 | +import torch |
| 7 | +from xformers.ops.fmha.attn_bias import PagedBlockDiagonalPaddedKeysMask |
| 8 | + |
| 9 | +from vllm.attention.backends.abstract import AttentionBackend |
| 10 | +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend |
| 11 | +from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend |
| 12 | + |
| 13 | + |
| 14 | +class NoOpLayerModule(torch.nn.Module): |
| 15 | + _q_scale = torch.tensor(1.0, dtype=torch.float32) |
| 16 | + _k_scale = torch.tensor(1.0, dtype=torch.float32) |
| 17 | + _v_scale = torch.tensor(1.0, dtype=torch.float32) |
| 18 | + |
| 19 | + def __init__(self): |
| 20 | + super().__init__() |
| 21 | + |
| 22 | + def forward(self, x): |
| 23 | + return x |
| 24 | + |
| 25 | + |
| 26 | +def forward_attention( |
| 27 | + batch_size: int, |
| 28 | + num_heads: int, |
| 29 | + num_kv_heads: int, |
| 30 | + dim_per_head: int, |
| 31 | + block_size: int, |
| 32 | + max_sequence_length: int, |
| 33 | + sequence_position: int, |
| 34 | + q_len: int, |
| 35 | + backends: list[type[AttentionBackend]], |
| 36 | +) -> list[torch.Tensor]: |
| 37 | + # Assert that the number of heads is divisible by the number of KV heads. |
| 38 | + assert num_heads % num_kv_heads == 0 |
| 39 | + |
| 40 | + device = "cuda" |
| 41 | + # Initialize q, k, and v. |
| 42 | + q = torch.randn( |
| 43 | + (batch_size * q_len, num_heads, dim_per_head), |
| 44 | + device=device, |
| 45 | + dtype=torch.bfloat16, |
| 46 | + ) |
| 47 | + k = torch.randn( |
| 48 | + (batch_size * q_len, num_kv_heads, dim_per_head), |
| 49 | + device=device, |
| 50 | + dtype=torch.bfloat16, |
| 51 | + ) |
| 52 | + v = torch.randn( |
| 53 | + (batch_size * q_len, num_kv_heads, dim_per_head), |
| 54 | + device=device, |
| 55 | + dtype=torch.bfloat16, |
| 56 | + ) |
| 57 | + |
| 58 | + # Initialize the query and KV sequence lengths. |
| 59 | + cu_seqlens_q = q_len * torch.arange( |
| 60 | + batch_size + 1, device=device, dtype=torch.int32) |
| 61 | + seqlens_q = torch.diff(cu_seqlens_q) |
| 62 | + seqlens_kv = torch.full( |
| 63 | + (batch_size, ), |
| 64 | + sequence_position + q_len, |
| 65 | + device=device, |
| 66 | + dtype=torch.int32, |
| 67 | + ) |
| 68 | + max_seqlen_q = q_len |
| 69 | + max_seqlen_k = sequence_position + q_len |
| 70 | + num_actual_tokens = cu_seqlens_q[-1] |
| 71 | + |
| 72 | + # Setup the block table and KV cache for paged KV. |
| 73 | + assert max_sequence_length % block_size == 0 |
| 74 | + max_block_count_per_batch = max_sequence_length // block_size |
| 75 | + kv_cache = torch.randn( |
| 76 | + ( |
| 77 | + 2, |
| 78 | + batch_size * max_block_count_per_batch, |
| 79 | + block_size, |
| 80 | + num_kv_heads, |
| 81 | + dim_per_head, |
| 82 | + ), |
| 83 | + device=device, |
| 84 | + dtype=torch.bfloat16, |
| 85 | + ) |
| 86 | + num_allocated_blocks_per_batch = math.ceil(max_seqlen_k / block_size) |
| 87 | + block_table = torch.zeros( |
| 88 | + (batch_size, max_block_count_per_batch), |
| 89 | + device=device, |
| 90 | + dtype=torch.int32, |
| 91 | + ) |
| 92 | + block_ids = torch.arange( |
| 93 | + 0, |
| 94 | + batch_size * num_allocated_blocks_per_batch, |
| 95 | + device=device, |
| 96 | + dtype=torch.int32, |
| 97 | + ).view(-1, num_allocated_blocks_per_batch) |
| 98 | + block_table[:, :num_allocated_blocks_per_batch] = block_ids |
| 99 | + |
| 100 | + # Setup the slot mapping for the input KVs. |
| 101 | + slots_per_batch = [] |
| 102 | + for i in range(batch_size): |
| 103 | + start_offset = block_ids[i, 0] * block_size + sequence_position |
| 104 | + slots_per_batch.append( |
| 105 | + torch.arange( |
| 106 | + start_offset, |
| 107 | + start_offset + q_len, |
| 108 | + device=device, |
| 109 | + dtype=torch.int64, |
| 110 | + )) |
| 111 | + slot_mapping = torch.cat(slots_per_batch, dim=0) |
| 112 | + |
| 113 | + softmax_scale = q.shape[-1]**(-0.5) |
| 114 | + layer = NoOpLayerModule() |
| 115 | + |
| 116 | + # Run attention for each backend and collect the outputs. |
| 117 | + outputs = [] |
| 118 | + for backend_cls in backends: |
| 119 | + # Set common metadata. |
| 120 | + attn_metadata_dict = { |
| 121 | + "num_actual_tokens": num_actual_tokens, |
| 122 | + "max_query_len": max_seqlen_q, |
| 123 | + "query_start_loc": cu_seqlens_q, |
| 124 | + "max_seq_len": max_seqlen_k, |
| 125 | + "seq_lens": seqlens_kv, |
| 126 | + "block_table": block_table, |
| 127 | + "slot_mapping": slot_mapping, |
| 128 | + } |
| 129 | + |
| 130 | + # Set backend-specific metadata. |
| 131 | + if backend_cls == FlashAttentionBackend: |
| 132 | + attn_metadata_dict["use_cascade"] = False |
| 133 | + attn_metadata_dict["common_prefix_len"] = 0 |
| 134 | + attn_metadata_dict["cu_prefix_query_lens"] = None |
| 135 | + attn_metadata_dict["prefix_kv_lens"] = None |
| 136 | + attn_metadata_dict["suffix_kv_lens"] = None |
| 137 | + elif backend_cls == TreeAttentionBackend: |
| 138 | + # Construct the prefix bias. |
| 139 | + prefix_kv_seqlens = seqlens_kv - seqlens_q |
| 140 | + prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens( |
| 141 | + q_seqlen=seqlens_q.tolist(), |
| 142 | + kv_seqlen=prefix_kv_seqlens.tolist(), |
| 143 | + page_size=block_size, |
| 144 | + block_tables=block_table, |
| 145 | + device=device, |
| 146 | + ) |
| 147 | + attn_metadata_dict["prefix_attn_bias"] = prefix_attn_bias |
| 148 | + # Create a chain attn bias. |
| 149 | + chain_attn_bias = torch.triu( |
| 150 | + torch.full((q_len, q_len), |
| 151 | + float("-inf"), |
| 152 | + device=device, |
| 153 | + dtype=torch.bfloat16), |
| 154 | + diagonal=1, |
| 155 | + ) |
| 156 | + attn_metadata_dict["spec_attn_bias"] = chain_attn_bias |
| 157 | + attn_metadata_dict["prefill_attn_metadata"] = None |
| 158 | + |
| 159 | + # Initialize the backend implementation. |
| 160 | + instance = backend_cls.get_impl_cls()( |
| 161 | + num_heads=num_heads, |
| 162 | + head_size=dim_per_head, |
| 163 | + scale=softmax_scale, |
| 164 | + num_kv_heads=num_kv_heads, |
| 165 | + alibi_slopes=None, |
| 166 | + sliding_window=None, |
| 167 | + kv_cache_dtype="auto", |
| 168 | + ) |
| 169 | + |
| 170 | + # Run forward pass and store output. |
| 171 | + output = torch.empty_like(q) |
| 172 | + outputs.append( |
| 173 | + instance.forward( |
| 174 | + layer=layer, |
| 175 | + query=q, |
| 176 | + key=k, |
| 177 | + value=v, |
| 178 | + kv_cache=kv_cache.clone(), |
| 179 | + attn_metadata=backend_cls.get_metadata_cls()( |
| 180 | + **attn_metadata_dict), |
| 181 | + output=output, |
| 182 | + )) |
| 183 | + return outputs |
| 184 | + |
| 185 | + |
| 186 | +def test_tree_attn_correctness() -> None: |
| 187 | + torch.cuda.manual_seed_all(0) |
| 188 | + |
| 189 | + for batch_size in [1, 2, 16, 32, 64]: |
| 190 | + for num_heads in [2, 4]: |
| 191 | + for sequence_position in [16, 1024, 2048]: |
| 192 | + for q_len in [1, 3, 7]: |
| 193 | + flash_attn_output, tree_attn_output = forward_attention( |
| 194 | + batch_size=batch_size, |
| 195 | + num_heads=num_heads, |
| 196 | + num_kv_heads=2, |
| 197 | + dim_per_head=128, |
| 198 | + block_size=128, |
| 199 | + max_sequence_length=8192, |
| 200 | + sequence_position=sequence_position, |
| 201 | + q_len=q_len, |
| 202 | + backends=[FlashAttentionBackend, TreeAttentionBackend], |
| 203 | + ) |
| 204 | + assert torch.allclose( |
| 205 | + flash_attn_output, tree_attn_output, atol=7.81e-3 |
| 206 | + ), (f"outputs are not close for batch_size: {batch_size}, " |
| 207 | + f"num_heads: {num_heads}, " |
| 208 | + f"sequence_position: {sequence_position}, " |
| 209 | + f"q_len: {q_len}.") |
0 commit comments