@@ -93,14 +93,18 @@ def paged_flash_attention_kernel(
     pages_per_sequence: int,
     mask_value: float,
     megacore_mode: str,
+    program_ids=(),
 ):
   """Pallas kernel for paged attention."""
-  core_index, b, h, i = (
-      pl.program_id(0),
-      pl.program_id(1),
-      pl.program_id(2),
-      pl.program_id(3),
-  )
+  if program_ids:
+    core_index, b, h, i = program_ids
+  else:
+    core_index, b, h, i = (
+        pl.program_id(0),
+        pl.program_id(1),
+        pl.program_id(2),
+        pl.program_id(3),
+    )
   num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
   bk = page_size * pages_per_compute_block
   num_cores = pl.num_programs(0)
@@ -229,12 +233,64 @@ def prefetch_next_block():
     step_ref[0] = step + 1
 
 
+def paged_flash_attention_kernel_inline_seq_dim(
+    lengths_ref,
+    page_indices_ref,
+    buffer_index_ref,
+    step_ref,
+    q_ref,
+    k_pages_hbm_ref,
+    v_pages_hbm_ref,
+    o_ref,
+    m_ref,
+    l_ref,
+    k_vmem_buffer,
+    v_vmem_buffer,
+    sem,
+    *,
+    batch_size: int,
+    pages_per_compute_block: int,
+    pages_per_sequence: int,
+    mask_value: float,
+    megacore_mode: str,
+):
+  core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+
+  def body(i, _):
+    paged_flash_attention_kernel(
+        lengths_ref,
+        page_indices_ref,
+        buffer_index_ref,
+        step_ref,
+        q_ref,
+        k_pages_hbm_ref,
+        v_pages_hbm_ref,
+        o_ref,
+        m_ref,
+        l_ref,
+        k_vmem_buffer,
+        v_vmem_buffer,
+        sem,
+        batch_size=batch_size,
+        pages_per_compute_block=pages_per_compute_block,
+        pages_per_sequence=pages_per_sequence,
+        mask_value=mask_value,
+        megacore_mode=megacore_mode,
+        program_ids=(core_index, b, h, i),
+    )
+    return ()
+
+  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+  lax.fori_loop(0, lax.div(lengths_ref[b] + bk - 1, bk), body, ())
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
        "pages_per_compute_block",
        "mask_value",
        "megacore_mode",
+       "inline_seq_dim",
    ],
 )
 def paged_attention(
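Note (not part of the patch): the new wrapper above collapses the sequence-axis grid dimension into a `lax.fori_loop` that calls the original kernel with an explicit `program_ids` tuple, so the trip count can follow the per-sequence length instead of the static `pages_per_sequence // pages_per_compute_block`. A minimal, self-contained sketch of that pattern is below; all names, shapes, and the `interpret=True` flag are illustrative and not taken from this file.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl

BLOCK = 128
NUM_BLOCKS = 8


def scale_block(x_ref, o_ref, *, block_id=None):
  # Doubles one block of the input; an explicit `block_id` overrides the
  # grid index, mirroring the `program_ids` override in this patch.
  i = pl.program_id(0) if block_id is None else block_id
  o_ref[pl.ds(i * BLOCK, BLOCK)] = x_ref[pl.ds(i * BLOCK, BLOCK)] * 2.0


def scale_inline(x_ref, o_ref):
  # A single kernel instance that walks the same blocks with lax.fori_loop.
  def body(i, _):
    scale_block(x_ref, o_ref, block_id=i)
    return ()

  lax.fori_loop(0, NUM_BLOCKS, body, ())


x = jnp.arange(BLOCK * NUM_BLOCKS, dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

# Grid version: one kernel instance per block, indexed by pl.program_id(0).
out_grid = pl.pallas_call(
    scale_block, grid=(NUM_BLOCKS,), out_shape=out_shape, interpret=True
)(x)

# Fused version: one instance loops over the blocks inside the kernel.
out_fused = pl.pallas_call(
    scale_inline, out_shape=out_shape, interpret=True
)(x)

assert jnp.array_equal(out_grid, out_fused)
```

In the toy the loop bound is static, but in the patch it is `ceil(lengths[b] / bk)`, a runtime value; a Pallas grid axis is fixed at trace time and cannot stop early like that.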
@@ -247,6 +303,7 @@ def paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     pages_per_compute_block: int,
     megacore_mode: Optional[str] = None,
+    inline_seq_dim: bool = True,
 ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
   """Paged grouped query attention.
 
@@ -271,6 +328,8 @@ def paged_attention(
         divisible by 2.
       * batch: megacore parallelism on batch dimension; requires batch divisible
         by 2.
+    inline_seq_dim: whether to fuse kernel instances along the sequence dim into
+      one kernel.
 
   Returns:
     The output of attention([batch_size, num_heads, head_dim]).
@@ -362,9 +421,31 @@ def paged_attention(
   )
   q_dtype_for_kernel_launch = q.dtype
 
+  if inline_seq_dim:
+    kernel = paged_flash_attention_kernel_inline_seq_dim
+    grid = (
+        num_cores,
+        batch_size // num_cores if megacore_mode == "batch" else batch_size,
+        num_kv_heads // num_cores
+        if megacore_mode == "kv_head"
+        else num_kv_heads,
+    )  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary")  # type: ignore
+  else:
+    kernel = paged_flash_attention_kernel
+    grid = (
+        num_cores,
+        batch_size // num_cores if megacore_mode == "batch" else batch_size,
+        num_kv_heads // num_cores
+        if megacore_mode == "kv_head"
+        else num_kv_heads,
+        pages_per_sequence // pages_per_compute_block,
+    )  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+
   out, _, _ = pl.pallas_call(
       functools.partial(
-          paged_flash_attention_kernel,
+          kernel,
           pages_per_sequence=pages_per_sequence,
           batch_size=batch_size,
           pages_per_compute_block=pages_per_compute_block,
@@ -383,16 +464,7 @@ def paged_attention(
           q_block_spec,
           q_block_spec,
       ],
-      grid=(
-          num_cores,
-          batch_size // num_cores
-          if megacore_mode == "batch"
-          else batch_size,
-          num_kv_heads // num_cores
-          if megacore_mode == "kv_head"
-          else num_kv_heads,
-          pages_per_sequence // pages_per_compute_block,
-      ),
+      grid=grid,
       scratch_shapes=(
           pltpu.VMEM(
               (
@@ -416,12 +488,7 @@ def paged_attention(
           ),
       ),
       mosaic_params=dict(
-          dimension_semantics=(
-              "parallel",
-              "arbitrary",
-              "arbitrary",
-              "arbitrary",
-          )
+          dimension_semantics=dimension_semantics,
      ),
      out_shape=[
          jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
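Usage sketch (assumptions, not from this diff): the import path, the positional argument order, and the array shapes below are inferred from the surrounding file — q is [batch_size, num_heads, head_dim], k_pages/v_pages are [num_kv_heads, total_num_pages, page_size, head_dim], lengths is [batch_size] int32, page_indices is [batch_size, pages_per_sequence] int32 — and the call only executes on a TPU backend.

```python
import jax.numpy as jnp
# Module path assumed; it may differ across JAX versions.
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import (
    paged_attention,
)

batch_size, num_heads, num_kv_heads, head_dim = 4, 16, 8, 128
page_size, pages_per_sequence = 16, 32          # up to 512 tokens per sequence
total_pages = batch_size * pages_per_sequence

q = jnp.zeros((batch_size, num_heads, head_dim), jnp.float32)
k_pages = jnp.zeros((num_kv_heads, total_pages, page_size, head_dim), jnp.float32)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.full((batch_size,), 300, dtype=jnp.int32)
page_indices = jnp.arange(total_pages, dtype=jnp.int32).reshape(
    batch_size, pages_per_sequence
)

# New fused path (default): a 3-D grid over (core, batch, kv_head); the
# KV-block loop runs inside the kernel and stops after ceil(length / bk) steps.
out_fused = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=4, inline_seq_dim=True,
)

# Previous behaviour: the KV-block loop is a fourth, statically sized grid axis.
out_unfused = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=4, inline_seq_dim=False,
)
```

With `inline_seq_dim=True` the launch uses the 3-D grid and the fused kernel added in this commit; `inline_seq_dim=False` keeps the original 4-D grid and kernel.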