
Commit 026f309

voutcn and jax authors authored and committed
Introduce an "inline_seq_dim" mode to the paged attention kernel. The mode fuses kernel instances along the sequence dimension into one kernel, thereby reducing the number of kernels.
PiperOrigin-RevId: 621611672
1 parent 0624775 commit 026f309
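
The fused mode trades the per-sequence-block grid axis for an in-kernel loop whose trip count is a ceiling division of the sequence length by the compute-block size. A minimal, runnable sketch of that looping pattern in plain JAX (illustrative names only; this is not the Pallas kernel itself):

# A minimal sketch, assuming nothing beyond public JAX APIs: one fused program
# loops over ceil(length / bk) sequence blocks with lax.fori_loop instead of
# dedicating a grid axis to them. `process_block` is a hypothetical stand-in
# for one sequence-block step of the kernel.
import jax.numpy as jnp
from jax import lax


def process_block(i, acc):
  # Placeholder for the work done on sequence block `i`.
  return acc + i


def fused_over_seq_dim(length, bk):
  # Trip count mirrors lax.div(lengths_ref[b] + bk - 1, bk) in the kernel:
  # a ceiling division of the sequence length by the compute-block size.
  num_blocks = (length + bk - 1) // bk
  return lax.fori_loop(0, num_blocks, process_block, jnp.int32(0))


print(fused_over_seq_dim(length=1000, bk=256))  # 4 blocks: 0 + 1 + 2 + 3 = 6

With length=1000 and bk=256 the loop runs four times, matching the kernel's lax.div(length + bk - 1, bk) trip count.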

File tree

1 file changed (+90 −23 lines)


jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

Lines changed: 90 additions & 23 deletions
@@ -93,14 +93,18 @@ def paged_flash_attention_kernel(
     pages_per_sequence: int,
     mask_value: float,
     megacore_mode: str,
+    program_ids=(),
 ):
   """Pallas kernel for paged attention."""
-  core_index, b, h, i = (
-      pl.program_id(0),
-      pl.program_id(1),
-      pl.program_id(2),
-      pl.program_id(3),
-  )
+  if program_ids:
+    core_index, b, h, i = program_ids
+  else:
+    core_index, b, h, i = (
+        pl.program_id(0),
+        pl.program_id(1),
+        pl.program_id(2),
+        pl.program_id(3),
+    )
   num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
   bk = page_size * pages_per_compute_block
   num_cores = pl.num_programs(0)
@@ -229,12 +233,64 @@ def prefetch_next_block():
     step_ref[0] = step + 1


+def paged_flash_attention_kernel_inline_seq_dim(
+    lengths_ref,
+    page_indices_ref,
+    buffer_index_ref,
+    step_ref,
+    q_ref,
+    k_pages_hbm_ref,
+    v_pages_hbm_ref,
+    o_ref,
+    m_ref,
+    l_ref,
+    k_vmem_buffer,
+    v_vmem_buffer,
+    sem,
+    *,
+    batch_size: int,
+    pages_per_compute_block: int,
+    pages_per_sequence: int,
+    mask_value: float,
+    megacore_mode: str,
+):
+  core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+
+  def body(i, _):
+    paged_flash_attention_kernel(
+        lengths_ref,
+        page_indices_ref,
+        buffer_index_ref,
+        step_ref,
+        q_ref,
+        k_pages_hbm_ref,
+        v_pages_hbm_ref,
+        o_ref,
+        m_ref,
+        l_ref,
+        k_vmem_buffer,
+        v_vmem_buffer,
+        sem,
+        batch_size=batch_size,
+        pages_per_compute_block=pages_per_compute_block,
+        pages_per_sequence=pages_per_sequence,
+        mask_value=mask_value,
+        megacore_mode=megacore_mode,
+        program_ids=(core_index, b, h, i),
+    )
+    return ()
+
+  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+  lax.fori_loop(0, lax.div(lengths_ref[b] + bk - 1, bk), body, ())
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
         "pages_per_compute_block",
         "mask_value",
         "megacore_mode",
+        "inline_seq_dim",
     ],
 )
 def paged_attention(
@@ -247,6 +303,7 @@ def paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     pages_per_compute_block: int,
     megacore_mode: Optional[str] = None,
+    inline_seq_dim: bool = True,
 ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
   """Paged grouped query attention.

@@ -271,6 +328,8 @@ def paged_attention(
         divisible by 2.
       * batch: megacore parallelism on batch dimension; requires batch divisible
         by 2.
+    inline_seq_dim: whether to fuse kernel instances along the sequence dim into
+      one kernel.

   Returns:
     The output of attention([batch_size, num_heads, head_dim]).
@@ -362,9 +421,31 @@ def paged_attention(
   )
   q_dtype_for_kernel_launch = q.dtype

+  if inline_seq_dim:
+    kernel = paged_flash_attention_kernel_inline_seq_dim
+    grid = (
+        num_cores,
+        batch_size // num_cores if megacore_mode == "batch" else batch_size,
+        num_kv_heads // num_cores
+        if megacore_mode == "kv_head"
+        else num_kv_heads,
+    )  # type: ignore
+    dimension_sematics = ("parallel", "arbitrary", "arbitrary")  # type: ignore
+  else:
+    kernel = paged_flash_attention_kernel
+    grid = (
+        num_cores,
+        batch_size // num_cores if megacore_mode == "batch" else batch_size,
+        num_kv_heads // num_cores
+        if megacore_mode == "kv_head"
+        else num_kv_heads,
+        pages_per_sequence // pages_per_compute_block,
+    )  # type: ignore
+    dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+
   out, _, _ = pl.pallas_call(
       functools.partial(
-          paged_flash_attention_kernel,
+          kernel,
           pages_per_sequence=pages_per_sequence,
           batch_size=batch_size,
           pages_per_compute_block=pages_per_compute_block,
@@ -383,16 +464,7 @@ def paged_attention(
           q_block_spec,
           q_block_spec,
       ],
-      grid=(
-          num_cores,
-          batch_size // num_cores
-          if megacore_mode == "batch"
-          else batch_size,
-          num_kv_heads // num_cores
-          if megacore_mode == "kv_head"
-          else num_kv_heads,
-          pages_per_sequence // pages_per_compute_block,
-      ),
+      grid=grid,
       scratch_shapes=(
           pltpu.VMEM(
              (
@@ -416,12 +488,7 @@ def paged_attention(
           ),
       ),
       mosaic_params=dict(
-          dimension_semantics=(
-              "parallel",
-              "arbitrary",
-              "arbitrary",
-              "arbitrary",
-          )
+          dimension_semantics=dimension_sematics,
      ),
      out_shape=[
          jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
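
For context, a hypothetical call site exercising the new flag. The argument order, shapes, and page layout below are assumptions drawn from the surrounding file rather than from this commit, and the kernel runs only on TPU:

# Hypothetical usage sketch; shapes and argument order are assumptions, not
# taken from the commit. TPU-only.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import (
    paged_attention,
)

batch, num_heads, num_kv_heads, head_dim = 4, 8, 8, 128
page_size, pages_per_sequence, total_pages = 16, 64, 512

q = jnp.zeros((batch, num_heads, head_dim), jnp.float32)
k_pages = jnp.zeros((num_kv_heads, total_pages, page_size, head_dim), jnp.float32)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.full((batch,), 512, jnp.int32)  # tokens per sequence
page_indices = jnp.arange(batch * pages_per_sequence, dtype=jnp.int32).reshape(
    batch, pages_per_sequence
)

# inline_seq_dim=True (the new default) launches a 3-D grid and loops over
# sequence blocks inside the kernel; False keeps the original 4-D grid.
out = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=4,
    inline_seq_dim=True,
)

Setting inline_seq_dim=False restores the previous behavior of one kernel instance per sequence block, which may be useful for comparing kernel counts between the two modes.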
