@@ -347,11 +347,11 @@ def _flash_attention_kernel_single_batch(
347
347
q_segment_ids_tile_ref ,
348
348
kv_segment_ids_tile_ref , # Input arrays
349
349
o_tile_ref , # Output arrays
350
+ l_ref ,
351
+ m_ref ,
350
352
m_scratch_ref ,
351
353
l_scratch_ref ,
352
354
acc_scratch_ref ,
353
- l_ref : Any | None = None ,
354
- m_ref : Any | None = None ,
355
355
* ,
356
356
causal ,
357
357
sm_scale ,
@@ -496,9 +496,6 @@ def _flash_attention_kernel_single_batch_single_step(
496
496
q_segment_ids_tile_ref ,
497
497
kv_segment_ids_tile_ref , # Input arrays
498
498
o_tile_ref , # Output arrays
499
- m_scratch_ref ,
500
- l_scratch_ref ,
501
- acc_scratch_ref ,
502
499
l_ref : Any | None = None ,
503
500
m_ref : Any | None = None ,
504
501
* ,
@@ -511,8 +508,6 @@ def _flash_attention_kernel_single_batch_single_step(
511
508
block_k_major = k_tile_ref .shape [2 ]
512
509
block_q = q_tile_ref .shape [2 ]
513
510
514
- scratch_refs = (m_scratch_ref , l_scratch_ref , acc_scratch_ref )
515
- assert all (ref is None for ref in scratch_refs )
516
511
assert kv_seq_len == block_k_major == block_k
517
512
518
513
q = q_tile_ref [batch_idx ] # [block_q, head_dim]
@@ -656,19 +651,12 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
656
651
out_specs = [pl .BlockSpec (o_index_map , (block_b , 1 , block_q , head_dim ))]
657
652
658
653
if block_k != kv_seq_len :
659
- scratch_shape = functools .partial (jax .ShapeDtypeStruct , dtype = jnp .float32 )
660
- m_scratch = scratch_shape ((block_b , 1 , block_q , MIN_BLOCK_SIZE ))
661
- l_scratch = scratch_shape ((block_b , 1 , block_q , MIN_BLOCK_SIZE ))
662
- acc_scratch = scratch_shape ((block_b , 1 , block_q , head_dim ))
663
- out_shape += [m_scratch , l_scratch , acc_scratch ]
664
- out_specs += [
665
- pl .BlockSpec (lambda * _ : (0 , 0 , 0 , 0 ), m_scratch .shape ),
666
- pl .BlockSpec (lambda * _ : (0 , 0 , 0 , 0 ), l_scratch .shape ),
667
- pl .BlockSpec (lambda * _ : (0 , 0 , 0 , 0 ), acc_scratch .shape ),
668
- ]
654
+ m_scratch = pltpu .VMEM ((block_b , 1 , block_q , MIN_BLOCK_SIZE ), jnp .float32 )
655
+ l_scratch = pltpu .VMEM ((block_b , 1 , block_q , MIN_BLOCK_SIZE ), jnp .float32 )
656
+ acc_scratch = pltpu .VMEM ((block_b , 1 , block_q , head_dim ), jnp .float32 )
657
+ scratch_shapes = [m_scratch , l_scratch , acc_scratch ]
669
658
else :
670
- out_shape += [None , None , None ]
671
- out_specs += [None , None , None ]
659
+ scratch_shapes = []
672
660
673
661
if save_residuals :
674
662
out_specs = [
@@ -683,6 +671,9 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
683
671
(batch_size , num_heads , q_seq_len , MIN_BLOCK_SIZE ), dtype = jnp .float32
684
672
)
685
673
out_shape = (* out_shape , l , m )
674
+ else :
675
+ out_specs = [* out_specs , None , None ]
676
+ out_shape = (* out_shape , None , None )
686
677
687
678
ab_block_spec = (
688
679
pl .BlockSpec (ab_index_map , (block_b , 1 , block_q , block_k_major ))
@@ -745,10 +736,14 @@ def kv_segment_ids_index_map(
745
736
746
737
o , * aux = pl .pallas_call (
747
738
kernel ,
739
+ grid_spec = pltpu .PrefetchScalarGridSpec (
740
+ num_scalar_prefetch = 0 ,
741
+ grid = grid ,
742
+ in_specs = in_specs ,
743
+ out_specs = out_specs ,
744
+ scratch_shapes = scratch_shapes ,
745
+ ),
748
746
out_shape = out_shape ,
749
- in_specs = in_specs ,
750
- out_specs = out_specs ,
751
- grid = grid ,
752
747
debug = debug ,
753
748
mosaic_params = dict (
754
749
dimension_semantics = ("parallel" , "parallel" , "parallel" , "arbitrary" )
@@ -1070,17 +1065,15 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
1070
1065
k .dtype ),
1071
1066
jax .ShapeDtypeStruct ((batch_size , num_heads , kv_seq_len , head_dim ),
1072
1067
v .dtype ),
1073
- jax .ShapeDtypeStruct ((block_k_major , head_dim ), jnp .float32 ),
1074
- jax .ShapeDtypeStruct ((block_k_major , head_dim ), jnp .float32 ),
1075
1068
]
1076
1069
def dkv_index_map (batch_index , head_index , kv_seq_index , _ ):
1077
1070
return (batch_index , head_index , kv_seq_index , 0 )
1078
1071
1079
1072
dkv_spec = pl .BlockSpec (dkv_index_map , (1 , 1 , block_k_major , head_dim ))
1080
- out_specs = [
1081
- dkv_spec , dkv_spec ,
1082
- pl . BlockSpec ( lambda * _ : ( 0 , 0 ), ( block_k_major , head_dim )),
1083
- pl . BlockSpec ( lambda * _ : ( 0 , 0 ), ( block_k_major , head_dim )),
1073
+ out_specs = [dkv_spec , dkv_spec ]
1074
+ scratch_shapes = [
1075
+ pltpu . VMEM (( block_k_major , head_dim ), jnp . float32 ), # type: ignore
1076
+ pltpu . VMEM (( block_k_major , head_dim ), jnp . float32 ), # type: ignore
1084
1077
]
1085
1078
1086
1079
kernel = functools .partial (
@@ -1094,12 +1087,16 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
1094
1087
)
1095
1088
name_scope = f"flash_mha_bwd_dkv_{ block_q_major = } _{ block_q = } _{ block_k_major = } _{ block_k = } "
1096
1089
with jax .named_scope (name_scope ):
1097
- dk , dv , _ , _ = pl .pallas_call (
1090
+ dk , dv = pl .pallas_call (
1098
1091
kernel ,
1099
- in_specs = in_specs , # type: ignore
1092
+ grid_spec = pltpu .PrefetchScalarGridSpec (
1093
+ num_scalar_prefetch = 0 ,
1094
+ grid = grid ,
1095
+ in_specs = in_specs , # type: ignore
1096
+ out_specs = out_specs ,
1097
+ scratch_shapes = scratch_shapes ,
1098
+ ),
1100
1099
out_shape = out_shapes ,
1101
- out_specs = out_specs ,
1102
- grid = grid ,
1103
1100
debug = debug ,
1104
1101
mosaic_params = dict (
1105
1102
dimension_semantics = (
@@ -1127,8 +1124,8 @@ def _flash_attention_dq_kernel(
1127
1124
do_tile_ref ,
1128
1125
di_tile_ref ,
1129
1126
dq_tile_ref ,
1130
- dq_scratch_ref ,
1131
1127
ds_tile_ref ,
1128
+ dq_scratch_ref ,
1132
1129
* ,
1133
1130
sm_scale : float ,
1134
1131
causal : bool ,
@@ -1414,15 +1411,14 @@ def kv_segment_ids_index_map(
1414
1411
1415
1412
out_shapes = [
1416
1413
jax .ShapeDtypeStruct (q .shape , q .dtype ),
1417
- jax .ShapeDtypeStruct ((block_q_major , head_dim ), jnp .float32 ),
1418
1414
jax .ShapeDtypeStruct (ab .shape , ab .dtype ) if ab is not None else None ,
1419
1415
]
1420
1416
dq_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , head_dim ))
1421
1417
out_specs = [
1422
1418
dq_spec ,
1423
- pl .BlockSpec (lambda * _ : (0 , 0 ), (block_q_major , head_dim )),
1424
1419
dab_spec ,
1425
1420
]
1421
+ scratch_shapes = [pltpu .VMEM ((block_q_major , head_dim ), jnp .float32 )] # type: ignore
1426
1422
1427
1423
kernel = functools .partial (
1428
1424
_flash_attention_dq_kernel ,
@@ -1434,12 +1430,16 @@ def kv_segment_ids_index_map(
1434
1430
)
1435
1431
name_scope = f"flash_mha_bwd_dq_{ block_q_major = } _{ block_k_major = } _{ block_k = } "
1436
1432
with jax .named_scope (name_scope ):
1437
- dq , _ , ds = pl .pallas_call (
1433
+ dq , ds = pl .pallas_call (
1438
1434
kernel ,
1439
- in_specs = in_specs , # type: ignore
1435
+ grid_spec = pltpu .PrefetchScalarGridSpec (
1436
+ num_scalar_prefetch = 0 ,
1437
+ grid = grid ,
1438
+ in_specs = in_specs , # type: ignore
1439
+ out_specs = out_specs , # type: ignore
1440
+ scratch_shapes = scratch_shapes ,
1441
+ ),
1440
1442
out_shape = out_shapes ,
1441
- out_specs = out_specs , # type: ignore
1442
- grid = grid ,
1443
1443
debug = debug ,
1444
1444
mosaic_params = dict (
1445
1445
dimension_semantics = (
0 commit comments