1
- # Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention
1
+ # Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention-with-torch-xla-dynamo
2
+ # python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention-with-torch-xla-nondynamo
3
+ # python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention
2
4
3
5
import argparse
4
6
import time
5
7
from typing import List , Optional , Tuple
6
8
import functools
9
+ import os
10
+ import sys
7
11
8
12
import torch
9
- import torch_xla
13
+ import torch_xla .debug .profiler as xp
14
+ import torch_xla .experimental .custom_kernel # Required to register custom ops.
10
15
import torch_xla .core .xla_model as xm
11
- import jax
12
- from jax ._src import test_util as jtu
13
- from jax .experimental .pallas .ops .tpu .paged_attention .paged_attention_kernel import paged_attention as jax_single_query_paged_attention
14
- import jax .numpy as jnp
16
+ from torch_xla import runtime as xr
15
17
import numpy as np
16
18
17
- from torch_xla .experimental .pallas_kernels .ragged_paged_attention_kernel import ragged_paged_attention , make_sequence_metadata , DEFAULT_MASK_VALUE
19
+ if xr .device_type () == 'TPU' :
20
+ from torch_xla .experimental .custom_kernel import jax_import_guard
21
+ jax_import_guard ()
22
+ import jax
23
+ import jax .numpy as jnp
24
+ from jax .experimental import pallas as pl
25
+ from torch_xla .experimental .pallas_kernels .ragged_paged_attention_kernel import ragged_paged_attention , make_sequence_metadata , DEFAULT_MASK_VALUE
18
26
19
27
20
28
def _ref_ragged_paged_attention (
@@ -78,10 +86,12 @@ def _ref_ragged_paged_attention(
78
86
return jnp .concatenate (outputs , axis = 0 )
79
87
80
88
81
- def _get_closest_power_of_two (x ):
82
- if x <= 0 :
83
- raise ValueError (f"x must be positive. Got { x } " )
84
- return 2 ** int (np .ceil (np .log2 (x )))
89
+ def _get_closest_of_multiple (x , base ):
90
+ return (x + base - 1 ) // base * base
91
+
92
+
93
+ def _run_with_torch_xla (kernel ):
94
+ return "torch-xla" in kernel
85
95
86
96
87
97
def benchmark (args ):
@@ -103,15 +113,13 @@ def benchmark(args):
103
113
dtype = jnp .float32
104
114
page_size = 16
105
115
num_pages = 32768
106
- num_queries_per_block = 128
116
+ num_queries_per_block = args .num_queries_per_block
117
+ num_kv_pages_per_block = args .num_kv_pages_per_block
107
118
108
119
num_seqs = len (seq_lens )
109
120
for i in range (num_seqs ):
110
121
cur_q_len = seq_lens [i ][0 ]
111
122
cur_kv_len = seq_lens [i ][1 ]
112
- # Make sure the q_len is no longer than the kv_len. For example,
113
- # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
114
- # the 3rd sequence has q_len(506) > kv_len(463).
115
123
assert cur_q_len <= cur_kv_len , f"cur_q_len must be less than or equal to cur_kv_len. Got { cur_q_len } and { cur_kv_len } "
116
124
117
125
query_lens = [seq_len [0 ] for seq_len in seq_lens ]
@@ -140,7 +148,8 @@ def benchmark(args):
140
148
max_num_pages_per_seq = (max_kv_len + page_size - 1 ) // page_size
141
149
# The reason why we need to pad max_num_pages_per_seq is that
142
150
# page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
143
- max_num_pages_per_seq = _get_closest_power_of_two (max_num_pages_per_seq )
151
+ max_num_pages_per_seq = _get_closest_of_multiple (max_num_pages_per_seq ,
152
+ num_kv_pages_per_block )
144
153
page_indices = jax .random .randint (
145
154
k4 , (num_q_tokens , max_num_pages_per_seq ), 0 , num_pages , dtype = jnp .int32 )
146
155
@@ -150,28 +159,68 @@ def benchmark(args):
150
159
q_lens_with_paddings [i ] = query_lens [i ]
151
160
cu_q_lens = jnp .cumsum (jnp .array ([0 ] + q_lens_with_paddings ))
152
161
153
- err , actual_output = ragged_paged_attention (
154
- queries ,
155
- k_pages ,
156
- v_pages ,
157
- kv_lens_np ,
158
- page_indices ,
159
- cu_q_lens ,
160
- num_seqs ,
161
- num_queries_per_block = num_queries_per_block ,
162
- )
163
- err .throw () # noop if there is no error.
164
- actual_output = jax .block_until_ready (actual_output )
165
- profile_path = "/workspaces/persist/myprofiles/plugins/profile"
162
+ if _run_with_torch_xla (args .kernel ):
163
+ queries_xla = torch .from_numpy (np .array (queries )).to (
164
+ torch .bfloat16 ).to ("xla" )
165
+ k_pages_xla = torch .from_numpy (np .array (k_pages )).to (
166
+ torch .bfloat16 ).to ("xla" )
167
+ v_pages_xla = torch .from_numpy (np .array (v_pages )).to (
168
+ torch .bfloat16 ).to ("xla" )
169
+ kv_lens_xla = torch .from_numpy (np .array (kv_lens_np )).to ("xla" )
170
+ page_indices_xla = torch .from_numpy (np .array (page_indices )).to ("xla" )
171
+ cu_q_lens_xla = torch .from_numpy (np .array (cu_q_lens )).to ("xla" )
166
172
167
- def run_benchmark (num_iters : int , profile : bool = False ) -> float :
173
+ def ragged_paged_attention_wrapper (q , k_pages , v_pages , kv_lens ,
174
+ page_indices , cu_q_lens , num_seqs ,
175
+ num_kv_pages_per_block ,
176
+ num_queries_per_block , use_kernel ):
177
+ return torch .ops .xla .ragged_paged_attention (
178
+ q ,
179
+ k_pages ,
180
+ v_pages ,
181
+ kv_lens ,
182
+ page_indices ,
183
+ cu_q_lens ,
184
+ num_seqs ,
185
+ num_kv_pages_per_block ,
186
+ num_queries_per_block ,
187
+ use_kernel = use_kernel ,
188
+ )
189
+
190
+ compiled_paged_attention = torch .compile (
191
+ ragged_paged_attention_wrapper , backend = "openxla" )
192
+
193
+ def run_benchmark (num_iters : int ) -> float :
168
194
start_time = time .perf_counter ()
169
- if profile :
170
- jax .profiler .start_trace (profile_path )
171
195
172
- actual_output = None
173
196
for _ in range (num_iters ):
174
- if args .kernel == "ragged-paged-attention" :
197
+ if args .kernel == "ragged-paged-attention-with-torch-xla-dynamo" :
198
+ compiled_paged_attention (
199
+ queries_xla ,
200
+ k_pages_xla ,
201
+ v_pages_xla ,
202
+ kv_lens_xla ,
203
+ page_indices_xla ,
204
+ cu_q_lens_xla ,
205
+ num_seqs ,
206
+ num_queries_per_block = num_queries_per_block ,
207
+ num_kv_pages_per_block = num_kv_pages_per_block ,
208
+ use_kernel = True ,
209
+ )
210
+ elif args .kernel == "ragged-paged-attention-with-torch-xla-nondynamo" :
211
+ torch .ops .xla .ragged_paged_attention (
212
+ queries_xla ,
213
+ k_pages_xla ,
214
+ v_pages_xla ,
215
+ kv_lens_xla ,
216
+ page_indices_xla ,
217
+ cu_q_lens_xla ,
218
+ num_seqs ,
219
+ num_queries_per_block = num_queries_per_block ,
220
+ num_kv_pages_per_block = num_kv_pages_per_block ,
221
+ use_kernel = True ,
222
+ )
223
+ elif args .kernel == "ragged-paged-attention" :
175
224
err , actual_output = ragged_paged_attention (
176
225
queries ,
177
226
k_pages ,
@@ -180,6 +229,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
180
229
page_indices ,
181
230
cu_q_lens ,
182
231
num_seqs ,
232
+ num_queries_per_block = num_queries_per_block ,
233
+ num_kv_pages_per_block = num_kv_pages_per_block ,
183
234
)
184
235
err .throw ()
185
236
elif args .kernel == "ragged-paged-attention-ref-impl" :
@@ -195,22 +246,23 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
195
246
else :
196
247
assert False , f"Invalid kernel name { args .kernel } "
197
248
198
- jax .block_until_ready (actual_output )
249
+ if _run_with_torch_xla (args .kernel ):
250
+ xm .mark_step ()
251
+ xm .wait_device_ops ()
252
+ else :
253
+ jax .block_until_ready (actual_output )
199
254
200
255
end_time = time .perf_counter ()
201
- if profile :
202
- jax .profiler .stop_trace ()
203
256
return (end_time - start_time ) / num_iters
204
257
205
258
# Warmup.
206
259
print ("Warming up..." )
207
- run_benchmark (num_iters = 3 , profile = False )
260
+ run_benchmark (num_iters = 3 )
208
261
209
- print ("Run benchmark..." )
210
- if args .profile :
211
- latency = run_benchmark (num_iters = 1 , profile = True )
212
- else :
213
- latency = run_benchmark (num_iters = 10 , profile = False )
262
+ print (
263
+ f"Run benchmark with { num_queries_per_block = } , { num_kv_pages_per_block = } ..."
264
+ )
265
+ latency = run_benchmark (num_iters = 10 )
214
266
print (f"Kernel running time: { latency * 1000000 :.3f} us" )
215
267
216
268
@@ -221,9 +273,13 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
221
273
type = str ,
222
274
choices = [
223
275
"ragged-paged-attention" ,
276
+ "ragged-paged-attention-with-torch-xla-dynamo" ,
277
+ "ragged-paged-attention-with-torch-xla-nondynamo" ,
224
278
"ragged-paged-attention-ref-impl" ,
225
279
],
226
280
default = "multi-queries-paged-attn" )
227
- parser .add_argument ("--profile" , action = "store_true" )
281
+ parser .add_argument ("--num-queries-per-block" , type = int , default = 128 )
282
+ parser .add_argument ("--num-kv-pages-per-block" , type = int , default = 128 )
228
283
args = parser .parse_args ()
284
+
229
285
benchmark (args )
0 commit comments