Commit 232f8c6

Add end-to-end example for paged attention (#104)
* e2e example for paged attention
1 parent 2e4d04a commit 232f8c6

6 files changed: +733 −3 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ Attention Gym is organized for easy exploration of attention mechanisms:

  - `attn_gym.masks`: Examples creating `BlockMasks`
  - `attn_gym.mods`: Examples creating `score_mods`
+ - `attn_gym.paged_attention`: Examples using `PagedAttention`
  - `examples/`: Detailed implementations using FlexAttention

## 🛠️ Dev
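
The README entry added above points at the new `attn_gym.paged_attention` package; its decode-time workflow is sketched below. This is a minimal, illustrative sketch rather than part of the commit: it reuses the `utils.py` and `model.py` helpers added here, assumes a CUDA GPU, should be run from `attn_gym/paged_attention/`, and uses sizes that are not the benchmark defaults.

# Minimal sketch (not part of the commit diff): one decode step through the paged layer.
import torch
from torch.nn.attention.flex_attention import create_block_mask, noop_mask

from model import PagedAttentionLayer
from utils import generate_score_mod, random_init_paged_attention

bsz, n_heads, head_dim, max_seq_len, page_size = 4, 16, 64, 2048, 128

# Logical KV space of bsz * max_seq_len tokens, backed by a shared pool of physical pages.
n_pages = (max_seq_len + page_size - 1) // page_size * bsz
paged_attention = random_init_paged_attention(n_pages, page_size, bsz, max_seq_len)

# Build a logical block mask / score_mod, then convert both to the physical (paged) layout.
block_mask = create_block_mask(noop_mask, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
converted_score_mod = paged_attention.get_score_mod(generate_score_mod("noop"))

# Decode one token per sequence at logical position max_seq_len // 2.
layer = torch.compile(
    PagedAttentionLayer(n_heads, head_dim, torch.bfloat16, paged_attention), fullgraph=True
)
batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32)  # [bsz]
input_pos = torch.full((bsz, 1), max_seq_len // 2, device="cuda", dtype=torch.int32)  # [bsz, 1]
x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=torch.bfloat16)
y = layer(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)  # [bsz, 1, n_heads * head_dim]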

attn_gym/paged_attention/latency.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
"""
Benchmarking the latency of a paged attention layer against a non-paged attention layer.

Command:
    python3 latency.py --setting change_max_seq_len
"""

import torch
from torch.nn.attention.flex_attention import (
    create_block_mask,
    noop_mask,
)
from torch._inductor.runtime.benchmarking import benchmarker

from utils import random_init_paged_attention, gen_offset, generate_score_mod

dtype = torch.bfloat16


def benchmark_layer(
    bsz,
    n_heads,
    max_seq_len,
    head_dim,
    paged_attention,
    batch_idx,
    input_pos,
    block_mask,
    score_mod,
    converted_block_mask,
    converted_score_mod,
    dtype=torch.bfloat16,
):
    from model import NonPagedAttentionLayer, PagedAttentionLayer

    # compile both models
    non_paged_foo = torch.compile(
        NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True
    )
    paged_foo = torch.compile(
        PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True
    )

    with torch.no_grad():
        # random token embedding for a single decode step
        x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=dtype)

        # warmup
        for _ in range(10):
            non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
            paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)

        # benchmark
        non_paged_latency = benchmarker.benchmark_gpu(
            lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
        )
        paged_latency = benchmarker.benchmark_gpu(
            lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
        )
        print(
            f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency / non_paged_latency - 1.0) * 100, 2)}%"
        )


def benchmark(
    attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int
):
    # For the decoding benchmark, set input_pos to half of max_seq_len
    input_pos = torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32).view(
        bsz, 1
    )  # [bsz, 1]
    batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32)  # [bsz]

    # init paged attention
    n_pages = (max_seq_len + page_size - 1) // page_size * bsz
    paged_attention = random_init_paged_attention(n_pages, page_size, bsz, max_seq_len)

    # Block mask
    if attn_type == "causal":
        mask_mod = gen_offset(
            torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32)
        )
    else:
        mask_mod = noop_mask
    block_mask = create_block_mask(mask_mod, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
    converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)

    # Score mod
    score_mod = generate_score_mod(attn_type)
    converted_score_mod = paged_attention.get_score_mod(score_mod)

    benchmark_layer(
        bsz,
        n_heads,
        max_seq_len,
        head_dim,
        paged_attention,
        batch_idx,
        input_pos,
        block_mask,
        score_mod,
        converted_block_mask,
        converted_score_mod,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--setting", type=str, default="change_max_seq_len")
    args = parser.parse_args()

    if args.setting == "change_max_seq_len":
        max_seq_len_candidates = [2048, 4096, 8192, 16384, 32768]
        bsz_candidates = [32]
        page_size_candidates = [128]
    elif args.setting == "change_bsz":
        max_seq_len_candidates = [8192]
        bsz_candidates = [32, 64, 128]
        page_size_candidates = [128]
    elif args.setting == "change_page_size":
        max_seq_len_candidates = [8192]
        bsz_candidates = [32]
        page_size_candidates = [64, 128, 256]
    else:
        raise NotImplementedError

    n_heads, head_dim = 16, 64

    for attn_type in ["noop", "causal", "rel", "head_bias"]:
        print(f"\nattn_type:{attn_type}")
        for page_size in page_size_candidates:
            print(f"page_size:{page_size}")
            for bsz in bsz_candidates:
                for max_seq_len in max_seq_len_candidates:
                    torch._dynamo.reset()

                    print(
                        f"\nbsz: {bsz}, max_seq_len: {max_seq_len}, head_dim: {head_dim}, n_heads: {n_heads}"
                    )
                    benchmark(attn_type, page_size, bsz, max_seq_len, n_heads, head_dim)
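
For reference, the page bookkeeping the benchmark relies on works out as follows for one of the swept configurations (bsz=32, max_seq_len=8192, page_size=128); the `max_cached_seq_len` expression mirrors how `PagedAttentionLayer` in model.py below sizes its physical caches. This is an illustrative aside, not part of the commit.

# Page-count arithmetic for bsz=32, max_seq_len=8192, page_size=128 (illustrative only).
bsz, max_seq_len, page_size = 32, 8192, 128
pages_per_seq = (max_seq_len + page_size - 1) // page_size  # 64 pages of 128 tokens each
n_pages = pages_per_seq * bsz                               # 2048 physical pages in the shared pool
max_cached_seq_len = n_pages * page_size                    # 262144 token slots in the paged kv cache
print(pages_per_seq, n_pages, max_cached_seq_len)           # 64 2048 262144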

attn_gym/paged_attention/model.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
import torch
import math
from torch.nn.attention.flex_attention import BlockMask, flex_attention, _score_mod_signature
from torch import Tensor
from typing import Dict, Optional


class NonPagedAttentionLayer(torch.nn.Module):
    """An attention layer without paged attention, ported from GPT-Fast:
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
    """

    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int = 32768):
        super().__init__()
        self.n_head = n_heads
        self.head_dim = head_dim

        # key, query, value projections for all heads, but in a batch
        total_head_dim = n_heads * head_dim
        self.wqkv = torch.nn.Linear(
            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.wo = torch.nn.Linear(
            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.k_cache = torch.randn(
            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
        )
        self.v_cache = torch.randn(
            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
        )
        self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype)

    def forward(
        self,
        batch_idx: Tensor,
        input_pos: Tensor,
        x: Tensor,
        block_mask: BlockMask,
        score_mod: _score_mod_signature,
    ) -> Tensor:
        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
        B, S, _ = x.shape

        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)

        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_head, self.head_dim)
        v = v.view(B, S, self.n_head, self.head_dim)

        freqs_cis = self.freqs_cis.unsqueeze(0)[
            torch.zeros((B, 1), dtype=torch.int), input_pos
        ]  # [B, S, D//2, 2]

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q = q.transpose(1, 2)
        self.k_cache[batch_idx.view(B, 1), :, input_pos] = k
        self.v_cache[batch_idx.view(B, 1), :, input_pos] = v

        y = flex_attention(
            q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod
        )

        y = y.transpose(1, 2).contiguous().view(B, S, -1)

        y = self.wo(y)
        return y


class PagedAttentionLayer(torch.nn.Module):
    """An attention layer with paged attention"""

    def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int = 65536):
        super().__init__()
        self.n_head = n_heads
        self.head_dim = head_dim

        # key, query, value projections for all heads, but in a batch
        total_head_dim = n_heads * head_dim
        self.wqkv = torch.nn.Linear(
            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.wo = torch.nn.Linear(
            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
        )

        # allocate kv cache with batch size=1 for paged attention
        max_cached_seq_len = paged_attention.n_pages * paged_attention.page_size
        self.k_cache_paged = torch.randn(
            1,
            n_heads,
            max_cached_seq_len,
            head_dim,
            device="cuda",
            dtype=dtype,
        )
        self.v_cache_paged = torch.randn(
            1,
            n_heads,
            max_cached_seq_len,
            head_dim,
            device="cuda",
            dtype=dtype,
        )
        self.paged_attention = paged_attention

        self.freqs_cis = precompute_freqs_cis(
            block_size, self.head_dim, dtype=dtype
        )  # [block_size, D//2, 2]

    def forward(
        self,
        batch_idx: Tensor,
        input_pos: Tensor,
        x: Tensor,
        converted_block_mask: BlockMask,
        converted_score_mod: _score_mod_signature,
    ) -> Tensor:
        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
        B, S, _ = x.shape
        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)

        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_head, self.head_dim)
        v = v.view(B, S, self.n_head, self.head_dim)

        freqs_cis = self.freqs_cis.unsqueeze(0)[
            torch.zeros((B, 1), dtype=torch.int), input_pos
        ]  # [B, S, D//2, 2]

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        # Compared with NonPagedAttentionLayer, this is the only change: the kv cache
        # update goes through paged_attention.assign
        self.paged_attention.assign(
            batch_idx, input_pos, k, v, self.k_cache_paged, self.v_cache_paged
        )

        y = flex_attention(
            q,
            self.k_cache_paged,
            self.v_cache_paged,
            block_mask=converted_block_mask,
            score_mod=converted_score_mod,
        )

        y = y.transpose(1, 2).contiguous().view(B, S, -1)

        y = self.wo(y)
        return y


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # x: [B, S, H, D], freqs_cis: [B, S, D//2, 2]
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    freqs_cis = freqs_cis.view(
        xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2
    )  # [B, S, 1, D//2, 2]
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Dict):
    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
    rope_scaling: Optional[dict] = None,
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    if rope_scaling is not None:
        freqs = apply_rope_scaling(freqs, rope_scaling)
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype, device="cuda")
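
The `freqs_cis` gather in both `forward` methods is easy to misread. The standalone shape check below (illustrative, not part of the commit; it needs a CUDA device because `precompute_freqs_cis` places the cache on "cuda") shows how indexing the `[block_size, D//2, 2]` cache with a `[B, 1]` zero tensor and a `[B, S]` `input_pos` yields the `[B, S, D//2, 2]` rotation factors consumed by `apply_rotary_emb`.

# Shape check for the rotary-cache gather used in forward() above (sketch only).
import torch

from model import precompute_freqs_cis

B, S, head_dim, block_size = 2, 1, 64, 128
cache = precompute_freqs_cis(block_size, head_dim, dtype=torch.bfloat16)  # [128, 32, 2]

input_pos = torch.tensor([[5], [97]], dtype=torch.int)  # [B, S] decode positions, one per sequence
freqs_cis = cache.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos]
print(cache.shape, freqs_cis.shape)  # torch.Size([128, 32, 2]) torch.Size([2, 1, 32, 2])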
