
Commit 79d1e86

Improve the ragged kernel benchmarking script. (#8733)
1 parent 1acc987 commit 79d1e86

File tree

1 file changed (+100, -44 lines)

test/benchmarks/test_ragged_paged_attention_benchmark.py

Lines changed: 100 additions & 44 deletions
@@ -1,20 +1,28 @@
-# Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention
+# Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention-with-torch-xla-dynamo
+# python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention-with-torch-xla-nondynamo
+# python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention
 
 import argparse
 import time
 from typing import List, Optional, Tuple
 import functools
+import os
+import sys
 
 import torch
-import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 import torch_xla.core.xla_model as xm
-import jax
-from jax._src import test_util as jtu
-from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_single_query_paged_attention
-import jax.numpy as jnp
+from torch_xla import runtime as xr
 import numpy as np
 
-from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE
+if xr.device_type() == 'TPU':
+  from torch_xla.experimental.custom_kernel import jax_import_guard
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  from jax.experimental import pallas as pl
+  from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE
 
 
 def _ref_ragged_paged_attention(
@@ -78,10 +86,12 @@ def _ref_ragged_paged_attention(
   return jnp.concatenate(outputs, axis=0)
 
 
-def _get_closest_power_of_two(x):
-  if x <= 0:
-    raise ValueError(f"x must be positive. Got {x}")
-  return 2**int(np.ceil(np.log2(x)))
+def _get_closest_of_multiple(x, base):
+  return (x + base - 1) // base * base
+
+
+def _run_with_torch_xla(kernel):
+  return "torch-xla" in kernel
 
 
 def benchmark(args):
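
As an aside on the helper swap above: the new _get_closest_of_multiple rounds up to the nearest multiple of an arbitrary base, so the padded page count can track whatever --num-kv-pages-per-block is passed in, whereas the removed helper always padded to the next power of two. A quick illustrative check (helper names and values are made up for this sketch, not taken from the script):

    import numpy as np

    def round_up_to_multiple(x, base):      # mirrors _get_closest_of_multiple
      return (x + base - 1) // base * base

    def round_up_to_power_of_two(x):        # mirrors the removed helper
      return 2**int(np.ceil(np.log2(x)))

    assert round_up_to_multiple(83, 128) == 128    # e.g. 83 pages, block size 128
    assert round_up_to_multiple(83, 96) == 96      # tighter fit for a non-power-of-two block
    assert round_up_to_power_of_two(83) == 128     # old behavior: always the next power of two
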
@@ -103,15 +113,13 @@ def benchmark(args):
   dtype = jnp.float32
   page_size = 16
   num_pages = 32768
-  num_queries_per_block = 128
+  num_queries_per_block = args.num_queries_per_block
+  num_kv_pages_per_block = args.num_kv_pages_per_block
 
   num_seqs = len(seq_lens)
   for i in range(num_seqs):
     cur_q_len = seq_lens[i][0]
     cur_kv_len = seq_lens[i][1]
-    # Make sure the q_len is no longer than the kv_len. For example,
-    # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
-    # the 3rd sequence has q_len(506) > kv_len(463).
     assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
 
   query_lens = [seq_len[0] for seq_len in seq_lens]
@@ -140,7 +148,8 @@ def benchmark(args):
   max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
   # The reason why we need to pad max_num_pages_per_seq is that
   # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-  max_num_pages_per_seq = _get_closest_power_of_two(max_num_pages_per_seq)
+  max_num_pages_per_seq = _get_closest_of_multiple(max_num_pages_per_seq,
+                                                   num_kv_pages_per_block)
   page_indices = jax.random.randint(
       k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32)
 
@@ -150,28 +159,68 @@
     q_lens_with_paddings[i] = query_lens[i]
   cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings))
 
-  err, actual_output = ragged_paged_attention(
-      queries,
-      k_pages,
-      v_pages,
-      kv_lens_np,
-      page_indices,
-      cu_q_lens,
-      num_seqs,
-      num_queries_per_block=num_queries_per_block,
-  )
-  err.throw()  # noop if there is no error.
-  actual_output = jax.block_until_ready(actual_output)
-  profile_path = "/workspaces/persist/myprofiles/plugins/profile"
+  if _run_with_torch_xla(args.kernel):
+    queries_xla = torch.from_numpy(np.array(queries)).to(
+        torch.bfloat16).to("xla")
+    k_pages_xla = torch.from_numpy(np.array(k_pages)).to(
+        torch.bfloat16).to("xla")
+    v_pages_xla = torch.from_numpy(np.array(v_pages)).to(
+        torch.bfloat16).to("xla")
+    kv_lens_xla = torch.from_numpy(np.array(kv_lens_np)).to("xla")
+    page_indices_xla = torch.from_numpy(np.array(page_indices)).to("xla")
+    cu_q_lens_xla = torch.from_numpy(np.array(cu_q_lens)).to("xla")
 
-  def run_benchmark(num_iters: int, profile: bool = False) -> float:
+  def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
+                                     page_indices, cu_q_lens, num_seqs,
+                                     num_kv_pages_per_block,
+                                     num_queries_per_block, use_kernel):
+    return torch.ops.xla.ragged_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_lens,
+        page_indices,
+        cu_q_lens,
+        num_seqs,
+        num_kv_pages_per_block,
+        num_queries_per_block,
+        use_kernel=use_kernel,
+    )
+
+  compiled_paged_attention = torch.compile(
+      ragged_paged_attention_wrapper, backend="openxla")
+
+  def run_benchmark(num_iters: int) -> float:
     start_time = time.perf_counter()
-    if profile:
-      jax.profiler.start_trace(profile_path)
 
-    actual_output = None
     for _ in range(num_iters):
-      if args.kernel == "ragged-paged-attention":
+      if args.kernel == "ragged-paged-attention-with-torch-xla-dynamo":
+        compiled_paged_attention(
+            queries_xla,
+            k_pages_xla,
+            v_pages_xla,
+            kv_lens_xla,
+            page_indices_xla,
+            cu_q_lens_xla,
+            num_seqs,
+            num_queries_per_block=num_queries_per_block,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            use_kernel=True,
+        )
+      elif args.kernel == "ragged-paged-attention-with-torch-xla-nondynamo":
+        torch.ops.xla.ragged_paged_attention(
+            queries_xla,
+            k_pages_xla,
+            v_pages_xla,
+            kv_lens_xla,
+            page_indices_xla,
+            cu_q_lens_xla,
+            num_seqs,
+            num_queries_per_block=num_queries_per_block,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            use_kernel=True,
+        )
+      elif args.kernel == "ragged-paged-attention":
         err, actual_output = ragged_paged_attention(
             queries,
             k_pages,
@@ -180,6 +229,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
             page_indices,
             cu_q_lens,
             num_seqs,
+            num_queries_per_block=num_queries_per_block,
+            num_kv_pages_per_block=num_kv_pages_per_block,
         )
         err.throw()
       elif args.kernel == "ragged-paged-attention-ref-impl":
@@ -195,22 +246,23 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
       else:
         assert False, f"Invalid kernel name {args.kernel}"
 
-    jax.block_until_ready(actual_output)
+    if _run_with_torch_xla(args.kernel):
+      xm.mark_step()
+      xm.wait_device_ops()
+    else:
+      jax.block_until_ready(actual_output)
 
     end_time = time.perf_counter()
-    if profile:
-      jax.profiler.stop_trace()
     return (end_time - start_time) / num_iters
 
   # Warmup.
   print("Warming up...")
-  run_benchmark(num_iters=3, profile=False)
+  run_benchmark(num_iters=3)
 
-  print("Run benchmark...")
-  if args.profile:
-    latency = run_benchmark(num_iters=1, profile=True)
-  else:
-    latency = run_benchmark(num_iters=10, profile=False)
+  print(
+      f"Run benchmark with {num_queries_per_block=}, {num_kv_pages_per_block=} ..."
+  )
+  latency = run_benchmark(num_iters=10)
   print(f"Kernel running time: {latency * 1000000:.3f} us")
 
 
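
On the new synchronization logic above: with PyTorch/XLA, pending lazy operations must be flushed (xm.mark_step()) and awaited (xm.wait_device_ops()) before the timer is read, otherwise only the asynchronous dispatch cost is measured; the JAX path gets the same guarantee from jax.block_until_ready. A minimal self-contained sketch of that pattern, using a toy matmul instead of the attention kernel (the helper name and shapes are illustrative only):

    import time

    import torch
    import torch_xla.core.xla_model as xm

    def timed_xla(fn, num_iters=10):
      """Average wall-clock seconds per call, synchronizing the XLA device."""
      start = time.perf_counter()
      for _ in range(num_iters):
        fn()
      xm.mark_step()        # flush pending lazy ops to the device
      xm.wait_device_ops()  # block until the device has finished
      return (time.perf_counter() - start) / num_iters

    a = torch.randn(128, 128).to("xla")
    b = torch.randn(128, 128).to("xla")
    toy = torch.compile(lambda: a @ b, backend="openxla")  # same backend as the dynamo path
    print(f"{timed_xla(toy) * 1e6:.3f} us")
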
@@ -221,9 +273,13 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
       type=str,
       choices=[
           "ragged-paged-attention",
+          "ragged-paged-attention-with-torch-xla-dynamo",
+          "ragged-paged-attention-with-torch-xla-nondynamo",
           "ragged-paged-attention-ref-impl",
       ],
       default="multi-queries-paged-attn")
-  parser.add_argument("--profile", action="store_true")
+  parser.add_argument("--num-queries-per-block", type=int, default=128)
+  parser.add_argument("--num-kv-pages-per-block", type=int, default=128)
   args = parser.parse_args()
+
   benchmark(args)
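
Because the block sizes are now plain command-line flags, sweeping them no longer requires editing the script. An illustrative driver (script path taken from the usage comment at the top; the kernel choice and flag values are arbitrary examples):

    import itertools
    import subprocess
    import sys

    SCRIPT = "pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py"
    for nq, nkv in itertools.product([64, 128, 256], repeat=2):
      subprocess.run([
          sys.executable, SCRIPT,
          "--kernel", "ragged-paged-attention-with-torch-xla-nondynamo",
          "--num-queries-per-block", str(nq),
          "--num-kv-pages-per-block", str(nkv),
      ], check=True)
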
