
Commit 0c4e82d

Split page indices in the ragged paged attention. (#8688)
1 parent 6016023 commit 0c4e82d

File tree: 5 files changed, +398 -87 lines changed

test/benchmarks/test_ragged_paged_attention_benchmark.py

Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
# Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention

import argparse
import time
from typing import List, Optional, Tuple
import functools

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import jax
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_single_query_paged_attention
import jax.numpy as jnp
import numpy as np

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE


def _ref_ragged_paged_attention(
    queries: jax.Array,  # [num_tokens, num_q_heads, head_dim]
    k_pages: jax.Array,  # [num_kv_heads, total_num_pages, page_size, head_dim]
    v_pages: jax.Array,  # [num_kv_heads, total_num_pages, page_size, head_dim]
    kv_lens: jax.Array,  # i32[num_tokens]
    page_indices: jax.Array,  # i32[num_tokens, pages_per_sequence]
    cu_q_lens: jax.Array,  # i32[num_tokens + 1]
    num_seqs: int,
):
  """This is the reference ragged paged attention implementation."""
  num_kv_heads, _, page_size, head_dim = k_pages.shape
  num_q_heads = queries.shape[1]
  assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
  num_query_per_kv = num_q_heads // num_kv_heads
  start_idx = 0
  outputs: List[jax.Array] = []
  for i in range(num_seqs):
    cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i]
    q = queries[start_idx:start_idx +
                cur_q_len]  # [cur_q_len, num_q_heads, head_dim]

    cur_kv_len = kv_lens[i]
    num_pages = (cur_kv_len + page_size - 1) // page_size
    page_indices_to_use = page_indices[i, :num_pages]
    k = k_pages[:, page_indices_to_use, :, :]  # [num_kv_heads, page_indices_to_use, page_size, head_dim]
    k = jnp.permute_dims(
        k, (1, 2, 0, 3))  # [page_indices_to_use, page_size, num_kv_heads, head_dim]
    k = jnp.reshape(
        k, (-1, num_kv_heads, head_dim))  # [kv_len, num_kv_heads, head_dim]
    k = k[:cur_kv_len]  # [cur_kv_lens, num_kv_heads, head_dim]

    v = v_pages[:, page_indices_to_use, :, :]
    v = jnp.permute_dims(v, (1, 2, 0, 3))
    v = jnp.reshape(v, (-1, num_kv_heads, head_dim))
    v = v[:cur_kv_len]  # [cur_kv_lens, num_kv_heads, head_dim]

    if num_query_per_kv != 1:
      k = jnp.repeat(k, num_query_per_kv, axis=1)
      v = jnp.repeat(v, num_query_per_kv, axis=1)

    attn = jnp.einsum("qhd,khd->hqk", q, k)
    attn = attn.astype('float32')
    q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota(
        jnp.int32, (cur_q_len, cur_kv_len), 0)
    kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1)
    # Use the same DEFAULT_MASK_VALUE as in the kernel instead of float("-inf") so that the kernel can match the ref implement better.
    mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.)
    with jax.numpy_rank_promotion("allow"):
      attn = attn + mask
    attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
    out = jnp.einsum("hqk,khd->qhd", attn,
                     v)  # [cur_q_len, num_q_heads, head_dim]

    outputs.append(out)
    start_idx += cur_q_len

  return jnp.concatenate(outputs, axis=0)


def _get_closest_power_of_two(x):
  if x <= 0:
    raise ValueError(f"x must be positive. Got {x}")
  return 2**int(np.ceil(np.log2(x)))


def benchmark(args):
  seq_lens = [
      (1, 1328),
      (5, 18),
      (1, 129),
      (120, 229),
      (1, 122),  # end of the first physical q block
      (1, 64),
      (32, 100),
      (250, 463),
      (1, 18),
      (1, 17),
      (99, 123),  # last 3 physical q blocks [(q_len, kv_len),...]
  ]
  num_heads = (4, 4)
  head_dim = 128
  dtype = jnp.float32
  page_size = 16
  num_pages = 32768
  num_queries_per_block = 128

  num_seqs = len(seq_lens)
  for i in range(num_seqs):
    cur_q_len = seq_lens[i][0]
    cur_kv_len = seq_lens[i][1]
    # Make sure the q_len is no longer than the kv_len. For example,
    # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
    # the 3rd sequence has q_len(506) > kv_len(463).
    assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"

  query_lens = [seq_len[0] for seq_len in seq_lens]
  num_q_tokens = sum(query_lens)
  kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens])
  num_q_heads = num_heads[0]
  num_kv_heads = num_heads[1]
  assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."

  prng_key = jax.random.key(0)
  k1, k2, k3, k4 = jax.random.split(prng_key, 4)
  queries = jax.random.normal(
      k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype)
  k_pages = jax.random.normal(
      k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype)
  v_pages = jax.random.normal(
      k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype)

  # Create a kv_lens: i32[num_tokens]
  kv_lens_with_paddings = [0] * num_q_tokens
  kv_lens_with_paddings[:num_seqs] = kv_lens[:num_seqs]
  kv_lens_np = jnp.array(kv_lens_with_paddings)

  # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence]
  max_kv_len = max([seq_len[1] for seq_len in seq_lens])
  max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
  # The reason why we need to pad max_num_pages_per_seq is that
  # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
  max_num_pages_per_seq = _get_closest_power_of_two(max_num_pages_per_seq)
  page_indices = jax.random.randint(
      k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32)

  # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1]
  q_lens_with_paddings = [0] * num_q_tokens
  for i in range(num_seqs):
    q_lens_with_paddings[i] = query_lens[i]
  cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings))

  err, actual_output = ragged_paged_attention(
      queries,
      k_pages,
      v_pages,
      kv_lens_np,
      page_indices,
      cu_q_lens,
      num_seqs,
      num_queries_per_block=num_queries_per_block,
  )
  err.throw()  # noop if there is no error.
  actual_output = jax.block_until_ready(actual_output)
  profile_path = "/workspaces/persist/myprofiles/plugins/profile"

  def run_benchmark(num_iters: int, profile: bool = False) -> float:
    start_time = time.perf_counter()
    if profile:
      jax.profiler.start_trace(profile_path)

    actual_output = None
    for _ in range(num_iters):
      if args.kernel == "ragged-paged-attention":
        err, actual_output = ragged_paged_attention(
            queries,
            k_pages,
            v_pages,
            kv_lens_np,
            page_indices,
            cu_q_lens,
            num_seqs,
        )
        err.throw()
      elif args.kernel == "ragged-paged-attention-ref-impl":
        actual_output = _ref_ragged_paged_attention(
            queries,
            k_pages,
            v_pages,
            kv_lens_np,
            page_indices,
            cu_q_lens,
            num_seqs,
        )
      else:
        assert False, f"Invalid kernel name {args.kernel}"

      jax.block_until_ready(actual_output)

    end_time = time.perf_counter()
    if profile:
      jax.profiler.stop_trace()
    return (end_time - start_time) / num_iters

  # Warmup.
  print("Warming up...")
  run_benchmark(num_iters=3, profile=False)

  print("Run benchmark...")
  if args.profile:
    latency = run_benchmark(num_iters=1, profile=True)
  else:
    latency = run_benchmark(num_iters=10, profile=False)
  print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--kernel",
      type=str,
      choices=[
          "ragged-paged-attention",
          "ragged-paged-attention-ref-impl",
      ],
      default="multi-queries-paged-attn")
  parser.add_argument("--profile", action="store_true")
  args = parser.parse_args()
  benchmark(args)
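
Aside (illustrative sketch, not part of the commit): the core of `_ref_ragged_paged_attention` above is the per-sequence causal mask built from two broadcasted iotas, where each query token may only attend up to its own absolute KV position. A minimal standalone version of that masking step, with made-up lengths and a stand-in constant for `DEFAULT_MASK_VALUE`:

import jax
import jax.numpy as jnp

cur_q_len, cur_kv_len = 3, 5  # illustrative; the real lengths come from cu_q_lens and kv_lens
MASK_VALUE = -1e30            # stand-in for the kernel's DEFAULT_MASK_VALUE

# Query token i sits at absolute KV position (cur_kv_len - cur_q_len) + i,
# so KV positions strictly beyond that are masked out.
q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota(
    jnp.int32, (cur_q_len, cur_kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1)
mask = jnp.where(q_span < kv_span, MASK_VALUE, 0.0)
# mask[i, j] == MASK_VALUE exactly when KV position j is in the future of query i:
# here row 0 masks columns 3-4, row 1 masks column 4, row 2 masks nothing.

The mask is added to the attention logits before the softmax, which is why the reference implementation uses the kernel's finite `DEFAULT_MASK_VALUE` rather than `float("-inf")`.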

test/test_pallas.py

Lines changed: 8 additions & 6 deletions
@@ -638,7 +638,7 @@ def test_paged_attention_wrapper(self):

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper(self):
+  def test_ragged_paged_attention_wrapper_without_dynamo(self):
     from torch_xla.experimental.custom_kernel import ragged_paged_attention
     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention

@@ -661,6 +661,7 @@ def test_ragged_paged_attention_wrapper(self):
     page_size = 16
     num_pages = 32768
     num_seqs = len(seq_lens)
+    num_kv_pages_per_block = 128
     num_queries_per_block = 8
     block_kv_size = 256

@@ -682,7 +683,7 @@ def test_ragged_paged_attention_wrapper(self):
         page_indices_xla,
         cu_q_lens_xla,
         num_seqs=num_seqs,
-        num_kv_pages_per_block=block_kv_size // page_size,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True)

@@ -694,7 +695,7 @@ def test_ragged_paged_attention_wrapper(self):
         page_indices_xla,
         cu_q_lens_xla,
         num_seqs=num_seqs,
-        num_kv_pages_per_block=block_kv_size // page_size,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False)

@@ -715,7 +716,7 @@ def test_ragged_paged_attention_wrapper(self):
             page_indices_jax,
             cu_q_lens_jax,
             num_seqs=num_seqs,
-            num_kv_pages_per_block=block_kv_size // page_size,
+            num_kv_pages_per_block=num_kv_pages_per_block,
             num_queries_per_block=num_queries_per_block,
         )[1]))

@@ -748,6 +749,7 @@ def test_ragged_paged_attention_wrapper_with_dynamo(self):
     page_size = 16
     num_pages = 32768
     num_seqs = len(seq_lens)
+    num_kv_pages_per_block = 128
     num_queries_per_block = 8
     block_kv_size = 256

@@ -789,7 +791,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         page_indices_xla,
         cu_q_lens_xla,
         num_seqs=num_seqs,
-        num_kv_pages_per_block=block_kv_size // page_size,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
     )
@@ -802,7 +804,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         page_indices_xla,
         cu_q_lens_xla,
         num_seqs=num_seqs,
-        num_kv_pages_per_block=block_kv_size // page_size,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False,
     )
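
Note (illustrative sketch, not part of the commit): these tests now pin `num_kv_pages_per_block` to 128 instead of deriving it from `block_kv_size // page_size`. With the constants declared in the tests themselves, the two values differ substantially, and the companion kernel test in the next file pads `page_indices` so its second dimension is a multiple of this block size:

# Constants copied from the test above.
page_size = 16
block_kv_size = 256

# Old value, derived from the KV block size in elements:
old_num_kv_pages_per_block = block_kv_size // page_size  # 256 // 16 == 16

# New value, pinned explicitly in the updated tests:
new_num_kv_pages_per_block = 128

# The kernel test in the next file rounds max_num_pages_per_seq up to a
# multiple of num_kv_pages_per_block, so this choice changes how much
# padding the tests construct for page_indices.
assert old_num_kv_pages_per_block == 16 and new_num_kv_pages_per_block == 128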

test/test_ragged_paged_attention_kernel.py

Lines changed: 16 additions & 5 deletions
@@ -86,6 +86,7 @@ def _verify_ragged_paged_attention(
       page_size,
       dtype,
       num_pages,
+      num_kv_pages_per_block=128,
       num_queries_per_block=128,
   ):
     num_seqs = len(seq_lens)
@@ -124,8 +125,8 @@ def _verify_ragged_paged_attention(
     max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = self._get_closest_power_of_two(
-        max_num_pages_per_seq)
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(
+        max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -149,6 +150,7 @@ def _verify_ragged_paged_attention(
         page_indices,
         cu_q_lens,
         num_seqs,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
     )
     err.throw()  # noop if there is not err.
@@ -183,6 +185,9 @@ def _verify_ragged_paged_attention(
     self.assertTrue(
         jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))

+  def _round_up_closest_multiple_of(self, x, base):
+    return (x + base - 1) // base * base
+
   def _get_closest_power_of_two(self, x):
     if x <= 0:
       raise ValueError(f"x must be positive. Got {x}")
@@ -225,14 +230,17 @@ def test_paged_attention_varlen_comprehensive(
       page_size: int,
       num_pages: int,
   ):
-    # assuming q_blk_size=128
+    if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
+      self.skipTest(
+          "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
     self._verify_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         dtype,
         num_pages,
+        num_queries_per_block=64,
     )

   def test_paged_attention_mix_prefill_and_decode1(self,):
@@ -326,6 +334,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence_min(self,):

   def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
     # assuming q_blk_size=128
+    # Here the q_len(1 or 511) is set up to be longer than the corresponding kv_len (0 or 256).
     seq_lens = [(1, 0), (511, 256)]  # [(q_len, kv_len),...]
     num_heads = (1, 1)
     head_dim = 128
@@ -361,8 +370,9 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
     max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = self._get_closest_power_of_two(
-        max_num_pages_per_seq)
+    num_kv_pages_per_block = 128
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(
+        max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -388,6 +398,7 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
         page_indices,
         cu_q_lens,
         num_seqs,
+        num_kv_pages_per_block=num_kv_pages_per_block,
     )
     err.throw()

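
Note (illustrative sketch, not part of the commit): the padding rule in these tests changes from rounding `max_num_pages_per_seq` up to the next power of two to rounding it up to a multiple of `num_kv_pages_per_block`, which matches the divisibility requirement stated in the inline comments (`max_num_pages_per_seq % num_kv_pages_per_compute_block == 0`). A standalone comparison of the two helpers, using an illustrative page count; the function bodies are adapted from the diff above:

import numpy as np

def get_closest_power_of_two(x):
  # Old padding rule (still used by the new benchmark script above).
  if x <= 0:
    raise ValueError(f"x must be positive. Got {x}")
  return 2**int(np.ceil(np.log2(x)))

def round_up_closest_multiple_of(x, base):
  # New padding rule introduced by this commit's tests.
  return (x + base - 1) // base * base

# E.g. a 463-token KV cache with page_size=16 needs 29 pages for that sequence.
max_num_pages_per_seq = 29
num_kv_pages_per_block = 128

print(get_closest_power_of_two(max_num_pages_per_seq))  # 32, not a multiple of 128
print(round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block))  # 128

The power-of-two padding can produce a width (here 32) that is not divisible by the KV block size, while the new rule pads directly to the required multiple.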
