Skip to content

Commit 51a31e5

Browse files
bythew3i authored and jax authors committed
[Pallas] Use scratch_shapes for scratch operands in flash attention kernel.
PiperOrigin-RevId: 611884935
1 parent 28f84eb commit 51a31e5

File tree

1 file changed

+39
-39
lines changed

1 file changed

+39
-39
lines changed

jax/experimental/pallas/ops/tpu/flash_attention.py

Lines changed: 39 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -347,11 +347,11 @@ def _flash_attention_kernel_single_batch(
347347
q_segment_ids_tile_ref,
348348
kv_segment_ids_tile_ref, # Input arrays
349349
o_tile_ref, # Output arrays
350+
l_ref,
351+
m_ref,
350352
m_scratch_ref,
351353
l_scratch_ref,
352354
acc_scratch_ref,
353-
l_ref: Any | None = None,
354-
m_ref: Any | None = None,
355355
*,
356356
causal,
357357
sm_scale,
@@ -496,9 +496,6 @@ def _flash_attention_kernel_single_batch_single_step(
496496
q_segment_ids_tile_ref,
497497
kv_segment_ids_tile_ref, # Input arrays
498498
o_tile_ref, # Output arrays
499-
m_scratch_ref,
500-
l_scratch_ref,
501-
acc_scratch_ref,
502499
l_ref: Any | None = None,
503500
m_ref: Any | None = None,
504501
*,
@@ -511,8 +508,6 @@ def _flash_attention_kernel_single_batch_single_step(
511508
block_k_major = k_tile_ref.shape[2]
512509
block_q = q_tile_ref.shape[2]
513510

514-
scratch_refs = (m_scratch_ref, l_scratch_ref, acc_scratch_ref)
515-
assert all(ref is None for ref in scratch_refs)
516511
assert kv_seq_len == block_k_major == block_k
517512

518513
q = q_tile_ref[batch_idx] # [block_q, head_dim]
@@ -656,19 +651,12 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
656651
out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))]
657652

658653
if block_k != kv_seq_len:
659-
scratch_shape = functools.partial(jax.ShapeDtypeStruct, dtype=jnp.float32)
660-
m_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
661-
l_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
662-
acc_scratch = scratch_shape((block_b, 1, block_q, head_dim))
663-
out_shape += [m_scratch, l_scratch, acc_scratch]
664-
out_specs += [
665-
pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape),
666-
pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape),
667-
pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape),
668-
]
654+
m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
655+
l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
656+
acc_scratch = pltpu.VMEM((block_b, 1, block_q, head_dim), jnp.float32)
657+
scratch_shapes = [m_scratch, l_scratch, acc_scratch]
669658
else:
670-
out_shape += [None, None, None]
671-
out_specs += [None, None, None]
659+
scratch_shapes = []
672660

673661
if save_residuals:
674662
out_specs = [
@@ -683,6 +671,9 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
683671
(batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
684672
)
685673
out_shape = (*out_shape, l, m)
674+
else:
675+
out_specs = [*out_specs, None, None]
676+
out_shape = (*out_shape, None, None)
686677

687678
ab_block_spec = (
688679
pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major))
@@ -745,10 +736,14 @@ def kv_segment_ids_index_map(
745736

746737
o, *aux = pl.pallas_call(
747738
kernel,
739+
grid_spec=pltpu.PrefetchScalarGridSpec(
740+
num_scalar_prefetch=0,
741+
grid=grid,
742+
in_specs=in_specs,
743+
out_specs=out_specs,
744+
scratch_shapes=scratch_shapes,
745+
),
748746
out_shape=out_shape,
749-
in_specs=in_specs,
750-
out_specs=out_specs,
751-
grid=grid,
752747
debug=debug,
753748
mosaic_params=dict(
754749
dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
@@ -1070,17 +1065,15 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
10701065
k.dtype),
10711066
jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim),
10721067
v.dtype),
1073-
jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
1074-
jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
10751068
]
10761069
def dkv_index_map(batch_index, head_index, kv_seq_index, _):
10771070
return (batch_index, head_index, kv_seq_index, 0)
10781071

10791072
dkv_spec = pl.BlockSpec(dkv_index_map, (1, 1, block_k_major, head_dim))
1080-
out_specs = [
1081-
dkv_spec, dkv_spec,
1082-
pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
1083-
pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
1073+
out_specs = [dkv_spec, dkv_spec]
1074+
scratch_shapes = [
1075+
pltpu.VMEM((block_k_major, head_dim), jnp.float32), # type: ignore
1076+
pltpu.VMEM((block_k_major, head_dim), jnp.float32), # type: ignore
10841077
]
10851078

10861079
kernel = functools.partial(
@@ -1094,12 +1087,16 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
10941087
)
10951088
name_scope = f"flash_mha_bwd_dkv_{block_q_major=}_{block_q=}_{block_k_major=}_{block_k=}"
10961089
with jax.named_scope(name_scope):
1097-
dk, dv, _, _ = pl.pallas_call(
1090+
dk, dv = pl.pallas_call(
10981091
kernel,
1099-
in_specs=in_specs, # type: ignore
1092+
grid_spec=pltpu.PrefetchScalarGridSpec(
1093+
num_scalar_prefetch=0,
1094+
grid=grid,
1095+
in_specs=in_specs, # type: ignore
1096+
out_specs=out_specs,
1097+
scratch_shapes=scratch_shapes,
1098+
),
11001099
out_shape=out_shapes,
1101-
out_specs=out_specs,
1102-
grid=grid,
11031100
debug=debug,
11041101
mosaic_params=dict(
11051102
dimension_semantics=(
@@ -1127,8 +1124,8 @@ def _flash_attention_dq_kernel(
11271124
do_tile_ref,
11281125
di_tile_ref,
11291126
dq_tile_ref,
1130-
dq_scratch_ref,
11311127
ds_tile_ref,
1128+
dq_scratch_ref,
11321129
*,
11331130
sm_scale: float,
11341131
causal: bool,
@@ -1414,15 +1411,14 @@ def kv_segment_ids_index_map(
14141411

14151412
out_shapes = [
14161413
jax.ShapeDtypeStruct(q.shape, q.dtype),
1417-
jax.ShapeDtypeStruct((block_q_major, head_dim), jnp.float32),
14181414
jax.ShapeDtypeStruct(ab.shape, ab.dtype) if ab is not None else None,
14191415
]
14201416
dq_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
14211417
out_specs = [
14221418
dq_spec,
1423-
pl.BlockSpec(lambda *_: (0, 0), (block_q_major, head_dim)),
14241419
dab_spec,
14251420
]
1421+
scratch_shapes = [pltpu.VMEM((block_q_major, head_dim), jnp.float32)] # type: ignore
14261422

14271423
kernel = functools.partial(
14281424
_flash_attention_dq_kernel,
@@ -1434,12 +1430,16 @@ def kv_segment_ids_index_map(
14341430
)
14351431
name_scope = f"flash_mha_bwd_dq_{block_q_major=}_{block_k_major=}_{block_k=}"
14361432
with jax.named_scope(name_scope):
1437-
dq, _, ds = pl.pallas_call(
1433+
dq, ds = pl.pallas_call(
14381434
kernel,
1439-
in_specs=in_specs, # type: ignore
1435+
grid_spec=pltpu.PrefetchScalarGridSpec(
1436+
num_scalar_prefetch=0,
1437+
grid=grid,
1438+
in_specs=in_specs, # type: ignore
1439+
out_specs=out_specs, # type: ignore
1440+
scratch_shapes=scratch_shapes,
1441+
),
14401442
out_shape=out_shapes,
1441-
out_specs=out_specs, # type: ignore
1442-
grid=grid,
14431443
debug=debug,
14441444
mosaic_params=dict(
14451445
dimension_semantics=(

0 commit comments

Comments (0)