20
20
import triton
21
21
import triton .language as tl
22
22
23
+ from ..utils .triton_utils import AsyncTaskContext
24
+
23
25
from .attention_utils import (
24
26
HAS_EXPLICIT_WS , # guard new tuning configs such as num_consumer_groups
25
27
HAS_TMA_DESC ,
@@ -273,14 +275,14 @@ def _attn_fwd_inner_ws(
273
275
for start_n in tl .range (lo , hi , BLOCK_N ): # , loop_schedule=LOOP_SCHEDULE):
274
276
start_n = tl .multiple_of (start_n , BLOCK_N )
275
277
# -- compute qk ----
276
- with tl . async_task ([0 ]):
278
+ with AsyncTaskContext ([0 ]):
277
279
if ENABLE_TMA :
278
280
k = desc_k .load (
279
281
[start_n .to (tl .int32 ) + (qvk_offset // stride_kn ).to (tl .int32 ), 0 ]
280
282
)
281
283
else :
282
284
k = tl .load (K_block_ptr , boundary_check = (1 ,), padding_option = "zero" )
283
- with tl . async_task ([1 , 2 ]):
285
+ with AsyncTaskContext ([1 , 2 ]):
284
286
if ENABLE_TMA :
285
287
k = tl .trans (k )
286
288
qk = tl .dot (q , k )
@@ -300,7 +302,7 @@ def _attn_fwd_inner_ws(
300
302
# -- update output accumulator --
301
303
acc = acc * alpha [:, None ]
302
304
# update acc
303
- with tl . async_task ([0 ]):
305
+ with AsyncTaskContext ([0 ]):
304
306
if ENABLE_TMA :
305
307
if fp8_v :
306
308
v = desc_v .load (
@@ -312,7 +314,7 @@ def _attn_fwd_inner_ws(
312
314
)
313
315
else :
314
316
v = tl .load (V_block_ptr , boundary_check = (0 ,), padding_option = "zero" )
315
- with tl . async_task ([1 , 2 ]):
317
+ with AsyncTaskContext ([1 , 2 ]):
316
318
if fp8_v :
317
319
if ENABLE_TMA :
318
320
v = tl .trans (v )
@@ -386,20 +388,20 @@ def _attn_fwd_inner_ws_with_dp(
386
388
for start_n in tl .range (lo , hi , BLOCK_N ): # , loop_schedule=LOOP_SCHEDULE):
387
389
start_n = tl .multiple_of (start_n , BLOCK_N )
388
390
# -- compute qk ----
389
- with tl . async_task ([LOAD_K ]):
391
+ with AsyncTaskContext ([LOAD_K ]):
390
392
if ENABLE_TMA :
391
393
k = desc_k .load (
392
394
[start_n .to (tl .int32 ) + (qvk_offset // stride_kn ).to (tl .int32 ), 0 ]
393
395
)
394
396
else :
395
397
k = tl .load (K_block_ptr , boundary_check = (1 ,), padding_option = "zero" )
396
398
397
- with tl . async_task ([FIRST_MMA ]):
399
+ with AsyncTaskContext ([FIRST_MMA ]):
398
400
if ENABLE_TMA : # feeds into gemm
399
401
k = tl .trans (k )
400
402
qk0 = tl .dot (q0 , k )
401
403
qk1 = tl .dot (q1 , k )
402
- with tl . async_task ([FIRST_SOFTMAX ]):
404
+ with AsyncTaskContext ([FIRST_SOFTMAX ]):
403
405
if STAGE == 2 :
404
406
mask = offs_m0 [:, None ] >= (start_n + offs_n [None , :])
405
407
qk0 = qk0 * qk_scale + tl .where (mask , 0 , - 1.0e6 )
@@ -413,22 +415,22 @@ def _attn_fwd_inner_ws_with_dp(
413
415
# -- update m_i and l_i
414
416
alpha0 = tl .math .exp2 (m_i0 - m_ij0 )
415
417
l_i0 = l_i0 * alpha0 + l_ij0
416
- with tl . async_task ([FIRST_CORRECTION ]):
418
+ with AsyncTaskContext ([FIRST_CORRECTION ]):
417
419
if ALPHA_REMAT :
418
420
alpha0_re = tl .math .exp2 (m_i0 - m_ij0 )
419
421
# -- update output accumulator --
420
422
acc0 = acc0 * alpha0_re [:, None ]
421
423
else :
422
424
acc0 = acc0 * alpha0 [:, None ]
423
- with tl . async_task ([FIRST_SOFTMAX ]):
425
+ with AsyncTaskContext ([FIRST_SOFTMAX ]):
424
426
# update acc
425
427
if fp8_v :
426
428
p0 = p0 .to (tl .float8e5 )
427
429
else :
428
430
p0 = p0 .to (tl .bfloat16 )
429
431
# update m_i and l_i
430
432
m_i0 = m_ij0
431
- with tl . async_task ([LAST_SOFTMAX ]):
433
+ with AsyncTaskContext ([LAST_SOFTMAX ]):
432
434
if STAGE == 2 :
433
435
mask = offs_m1 [:, None ] >= (start_n + offs_n [None , :])
434
436
qk1 = qk1 * qk_scale + tl .where (mask , 0 , - 1.0e6 )
@@ -442,22 +444,22 @@ def _attn_fwd_inner_ws_with_dp(
442
444
# -- update m_i and l_i
443
445
alpha1 = tl .math .exp2 (m_i1 - m_ij1 )
444
446
l_i1 = l_i1 * alpha1 + l_ij1
445
- with tl . async_task ([LAST_CORRECTION ]):
447
+ with AsyncTaskContext ([LAST_CORRECTION ]):
446
448
if ALPHA_REMAT :
447
449
alpha1_re = tl .math .exp2 (m_i1 - m_ij1 )
448
450
# -- update output accumulator --
449
451
acc1 = acc1 * alpha1_re [:, None ]
450
452
else :
451
453
acc1 = acc1 * alpha1 [:, None ]
452
- with tl . async_task ([LAST_SOFTMAX ]):
454
+ with AsyncTaskContext ([LAST_SOFTMAX ]):
453
455
# update acc
454
456
if fp8_v :
455
457
p1 = p1 .to (tl .float8e5 )
456
458
else :
457
459
p1 = p1 .to (tl .bfloat16 )
458
460
# update m_i and l_i
459
461
m_i1 = m_ij1
460
- with tl . async_task ([LOAD_V ]):
462
+ with AsyncTaskContext ([LOAD_V ]):
461
463
if ENABLE_TMA :
462
464
if fp8_v :
463
465
v = desc_v .load (
@@ -469,7 +471,7 @@ def _attn_fwd_inner_ws_with_dp(
469
471
)
470
472
else :
471
473
v = tl .load (V_block_ptr , boundary_check = (0 ,), padding_option = "zero" )
472
- with tl . async_task ([LAST_MMA ]):
474
+ with AsyncTaskContext ([LAST_MMA ]):
473
475
if fp8_v :
474
476
if ENABLE_TMA :
475
477
v = tl .trans (v )
@@ -941,7 +943,7 @@ def _attn_fwd_compute_ws(
941
943
qk_scale = sm_scale
942
944
qk_scale *= 1.44269504 # 1/log(2)
943
945
# load q: it will stay in SRAM throughout
944
- with tl . async_task ([0 ]):
946
+ with AsyncTaskContext ([0 ]):
945
947
if ENABLE_TMA :
946
948
q = desc_q .load (
947
949
[(qvk_offset // stride_qm + start_m * BLOCK_M ).to (tl .int32 ), 0 ]
@@ -1011,7 +1013,7 @@ def _attn_fwd_compute_ws(
1011
1013
LOOP_SCHEDULE ,
1012
1014
)
1013
1015
# epilogue
1014
- with tl . async_task ([1 , 2 ]):
1016
+ with AsyncTaskContext ([1 , 2 ]):
1015
1017
m_i += tl .math .log2 (l_i )
1016
1018
acc = acc / l_i [:, None ]
1017
1019
m_ptrs = M + off_hz * N_CTX + offs_m
@@ -1104,14 +1106,14 @@ def _attn_fwd_compute_ws_with_dp(
1104
1106
qk_scale *= 1.44269504 # 1/log(2)
1105
1107
# load q: it will stay in SRAM throughout
1106
1108
# q0 will be BLOCK_M, each kernel invocation will handle 2 * BLOCK_M
1107
- with tl . async_task ([FIRST_LOADQ ]):
1109
+ with AsyncTaskContext ([FIRST_LOADQ ]):
1108
1110
if ENABLE_TMA :
1109
1111
q0 = desc_q .load (
1110
1112
[(qvk_offset // stride_qm + start_m * BLOCK_M ).to (tl .int32 ), 0 ]
1111
1113
)
1112
1114
else :
1113
1115
q0 = tl .load (Q_block_ptr )
1114
- with tl . async_task ([LAST_LOADQ ]):
1116
+ with AsyncTaskContext ([LAST_LOADQ ]):
1115
1117
if ENABLE_TMA :
1116
1118
q1 = desc_q .load (
1117
1119
[
@@ -1214,7 +1216,7 @@ def _attn_fwd_compute_ws_with_dp(
1214
1216
ALPHA_REMAT ,
1215
1217
)
1216
1218
# epilogue
1217
- with tl . async_task ([FIRST_SOFTMAX ]):
1219
+ with AsyncTaskContext ([FIRST_SOFTMAX ]):
1218
1220
m_i0 += tl .math .log2 (l_i0 )
1219
1221
acc0 = acc0 / l_i0 [:, None ]
1220
1222
m_ptrs0 = M + off_hz * N_CTX + offs_m0
@@ -1226,7 +1228,7 @@ def _attn_fwd_compute_ws_with_dp(
1226
1228
)
1227
1229
else :
1228
1230
tl .store (O_block_ptr , acc .to (Out .type .element_ty ))
1229
- with tl . async_task ([LAST_SOFTMAX ]):
1231
+ with AsyncTaskContext ([LAST_SOFTMAX ]):
1230
1232
m_i1 += tl .math .log2 (l_i1 )
1231
1233
acc1 = acc1 / l_i1 [:, None ]
1232
1234
m_ptrs1 = M + off_hz * N_CTX + offs_m1
0 commit comments