
Commit 677c149

[spec decoding] add test for tree attention correctness
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Parent: e6d3e73

1 file changed: 206 additions, 0 deletions
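The test below feeds identical query/key/value tensors and copies of the same randomly filled paged KV cache to both the FlashAttention and TreeAttention v1 backends, using a chain-shaped (purely causal) speculative mask, and asserts the two outputs agree within a small tolerance. As a minimal sketch of how one might run just this test locally (it assumes a CUDA device and a module filename of test_tree_attention.py, which is not shown in this commit view):

# Hypothetical local runner; the filename "test_tree_attention.py" is assumed.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["-q", "-k", "test_tree_attn_correctness",
                     "test_tree_attention.py"]))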

@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import torch
from xformers.ops.fmha.attn_bias import PagedBlockDiagonalPaddedKeysMask

from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend


class NoOpLayerModule(torch.nn.Module):
    _q_scale = torch.tensor(1.0, dtype=torch.float32)
    _k_scale = torch.tensor(1.0, dtype=torch.float32)
    _v_scale = torch.tensor(1.0, dtype=torch.float32)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def forward_attention(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    dim_per_head: int,
    block_size: int,
    max_sequence_length: int,
    sequence_position: int,
    q_len: int,
    backends: list[type[AttentionBackend]],
) -> list[torch.Tensor]:
    # Assert that the number of heads is divisible by the number of KV heads.
    assert num_heads % num_kv_heads == 0

    device = "cuda"
    # Initialize q, k, and v.
    q = torch.randn(
        (batch_size * q_len, num_heads, dim_per_head),
        device=device,
        dtype=torch.bfloat16,
    )
    k = torch.randn(
        (batch_size * q_len, num_kv_heads, dim_per_head),
        device=device,
        dtype=torch.bfloat16,
    )
    v = torch.randn(
        (batch_size * q_len, num_kv_heads, dim_per_head),
        device=device,
        dtype=torch.bfloat16,
    )

    # Initialize the query and KV sequence lengths.
    cu_seqlens_q = q_len * torch.arange(
        batch_size + 1, device=device, dtype=torch.int32)
    seqlens_q = torch.diff(cu_seqlens_q)
    seqlens_kv = torch.full(
        (batch_size, ),
        sequence_position + q_len,
        device=device,
        dtype=torch.int32,
    )
    max_seqlen_q = q_len
    max_seqlen_k = sequence_position + q_len
    num_actual_tokens = cu_seqlens_q[-1]

    # Setup the block table and KV cache for paged KV.
    assert max_sequence_length % block_size == 0
    max_block_count_per_batch = max_sequence_length // block_size
    kv_cache = torch.randn(
        (
            2,
            batch_size * max_block_count_per_batch,
            block_size,
            num_kv_heads,
            dim_per_head,
        ),
        device=device,
        dtype=torch.bfloat16,
    )
    num_allocated_blocks_per_batch = math.ceil(max_seqlen_k / block_size)
    block_table = torch.zeros(
        (batch_size, max_block_count_per_batch),
        device=device,
        dtype=torch.int32,
    )
    block_ids = torch.arange(
        0,
        batch_size * num_allocated_blocks_per_batch,
        device=device,
        dtype=torch.int32,
    ).view(-1, num_allocated_blocks_per_batch)
    block_table[:, :num_allocated_blocks_per_batch] = block_ids

    # Setup the slot mapping for the input KVs.
    slots_per_batch = []
    for i in range(batch_size):
        start_offset = block_ids[i, 0] * block_size + sequence_position
        slots_per_batch.append(
            torch.arange(
                start_offset,
                start_offset + q_len,
                device=device,
                dtype=torch.int64,
            ))
    slot_mapping = torch.cat(slots_per_batch, dim=0)

    softmax_scale = q.shape[-1]**(-0.5)
    layer = NoOpLayerModule()

    # Run attention for each backend and collect the outputs.
    outputs = []
    for backend_cls in backends:
        # Set common metadata.
        attn_metadata_dict = {
            "num_actual_tokens": num_actual_tokens,
            "max_query_len": max_seqlen_q,
            "query_start_loc": cu_seqlens_q,
            "max_seq_len": max_seqlen_k,
            "seq_lens": seqlens_kv,
            "block_table": block_table,
            "slot_mapping": slot_mapping,
        }

        # Set backend-specific metadata.
        if backend_cls == FlashAttentionBackend:
            attn_metadata_dict["use_cascade"] = False
            attn_metadata_dict["common_prefix_len"] = 0
            attn_metadata_dict["cu_prefix_query_lens"] = None
            attn_metadata_dict["prefix_kv_lens"] = None
            attn_metadata_dict["suffix_kv_lens"] = None
        elif backend_cls == TreeAttentionBackend:
            # Construct the prefix bias.
            prefix_kv_seqlens = seqlens_kv - seqlens_q
            prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens(
                q_seqlen=seqlens_q.tolist(),
                kv_seqlen=prefix_kv_seqlens.tolist(),
                page_size=block_size,
                block_tables=block_table,
                device=device,
            )
            attn_metadata_dict["prefix_attn_bias"] = prefix_attn_bias
            # Create a chain attn bias.
            chain_attn_bias = torch.triu(
                torch.full((q_len, q_len),
                           float("-inf"),
                           device=device,
                           dtype=torch.bfloat16),
                diagonal=1,
            )
            attn_metadata_dict["spec_attn_bias"] = chain_attn_bias
            attn_metadata_dict["prefill_attn_metadata"] = None

        # Initialize the backend implementation.
        instance = backend_cls.get_impl_cls()(
            num_heads=num_heads,
            head_size=dim_per_head,
            scale=softmax_scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
        )

        # Run forward pass and store output.
        output = torch.empty_like(q)
        outputs.append(
            instance.forward(
                layer=layer,
                query=q,
                key=k,
                value=v,
                kv_cache=kv_cache.clone(),
                attn_metadata=backend_cls.get_metadata_cls()(
                    **attn_metadata_dict),
                output=output,
            ))
    return outputs


def test_tree_attn_correctness() -> None:
    torch.cuda.manual_seed_all(0)

    for batch_size in [1, 2, 16, 32, 64]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for q_len in [1, 3, 7]:
                    flash_attn_output, tree_attn_output = forward_attention(
                        batch_size=batch_size,
                        num_heads=num_heads,
                        num_kv_heads=2,
                        dim_per_head=128,
                        block_size=128,
                        max_sequence_length=8192,
                        sequence_position=sequence_position,
                        q_len=q_len,
                        backends=[FlashAttentionBackend, TreeAttentionBackend],
                    )
                    assert torch.allclose(
                        flash_attn_output, tree_attn_output, atol=7.81e-3
                    ), (f"outputs are not close for batch_size: {batch_size}, "
                        f"num_heads: {num_heads}, "
                        f"sequence_position: {sequence_position}, "
                        f"q_len: {q_len}.")
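For reference, the "chain" speculative bias passed to the TreeAttention backend above is a strictly upper-triangular matrix of -inf, i.e. an ordinary causal mask over the q_len speculative tokens; that is why the tree backend's output is expected to match FlashAttention's causal output. A minimal sketch of the same construction for q_len = 3:

import torch

# Same construction as chain_attn_bias above, shown for q_len = 3.
q_len = 3
bias = torch.triu(
    torch.full((q_len, q_len), float("-inf"), dtype=torch.bfloat16),
    diagonal=1,
)
# bias:
# [[0., -inf, -inf],
#  [0.,   0., -inf],
#  [0.,   0.,   0.]]
# Token i may attend only to tokens 0..i, i.e. a standard causal pattern.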
