
Commit 91cc742

Use AsyncTaskContext to replace tl.async_task usage (#288)
1 parent 5b6af5e commit 91cc742

File tree

3 files changed: +44, -40 lines changed

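Summary: newer Triton releases (the deleted run.py comment cites 3.4+) may not expose tl.async_task, so tritonbench previously monkey-patched triton.language at import time. This commit moves the fallback into a reusable AsyncTaskContext wrapper and rewrites every kernel call site one-for-one. A minimal sketch of the call-site pattern, with an illustrative task id and load borrowed from the diff below:

# Before: raises AttributeError on Triton builds without tl.async_task,
# unless the run.py monkey-patch ran first
with tl.async_task([0]):
    k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")

# After: forwards to tl.async_task when it exists, otherwise a no-op context
with AsyncTaskContext([0]):
    k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")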

run.py

Lines changed: 0 additions & 20 deletions
@@ -10,28 +10,8 @@
 import sys
 from typing import List
 
-# Apply async_task patch for Triton 3.4+ compatibility
-import triton.language as tl
-
 from tritonbench.operator_loader import get_op_loader_bench_cls_by_name, is_loader_op
 
-if not hasattr(tl, "async_task"):
-
-    class _AsyncTaskContext:
-        """A no-op context manager to replace tl.async_task"""
-
-        def __init__(self, task_ids):
-            self.task_ids = task_ids
-
-        def __enter__(self):
-            return self
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Add async_task to triton.language
-    tl.async_task = lambda task_ids: _AsyncTaskContext(task_ids)
-
 from tritonbench.operators import load_opbench_by_name
 from tritonbench.operators_collection import list_operators_by_collection
 from tritonbench.utils.env_utils import is_fbcode
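The deleted block above mutated triton.language globally at import time, so kernels imported without going through run.py still failed on tl.async_task. After this commit the feature check lives next to its only consumer; a condensed sketch of the detection idiom (the constant name HAS_ASYNC_TASK is illustrative, not part of the commit):

import triton.language as tl

# Probe once for warp-specialization support instead of patching the module
HAS_ASYNC_TASK = hasattr(tl, "async_task")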

tritonbench/kernels/triton_fused_attention.py

Lines changed: 22 additions & 20 deletions
@@ -20,6 +20,8 @@
 import triton
 import triton.language as tl
 
+from ..utils.triton_utils import AsyncTaskContext
+
 from .attention_utils import (
     HAS_EXPLICIT_WS,  # guard new tuning configs such as num_consumer_groups
     HAS_TMA_DESC,
@@ -273,14 +275,14 @@ def _attn_fwd_inner_ws(
     for start_n in tl.range(lo, hi, BLOCK_N):  # , loop_schedule=LOOP_SCHEDULE):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        with tl.async_task([0]):
+        with AsyncTaskContext([0]):
             if ENABLE_TMA:
                 k = desc_k.load(
                     [start_n.to(tl.int32) + (qvk_offset // stride_kn).to(tl.int32), 0]
                 )
             else:
                 k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
-        with tl.async_task([1, 2]):
+        with AsyncTaskContext([1, 2]):
             if ENABLE_TMA:
                 k = tl.trans(k)
             qk = tl.dot(q, k)
@@ -300,7 +302,7 @@ def _attn_fwd_inner_ws(
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # update acc
-        with tl.async_task([0]):
+        with AsyncTaskContext([0]):
             if ENABLE_TMA:
                 if fp8_v:
                     v = desc_v.load(
@@ -312,7 +314,7 @@ def _attn_fwd_inner_ws(
                     )
             else:
                 v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
-        with tl.async_task([1, 2]):
+        with AsyncTaskContext([1, 2]):
            if fp8_v:
                if ENABLE_TMA:
                    v = tl.trans(v)
@@ -386,20 +388,20 @@ def _attn_fwd_inner_ws_with_dp(
     for start_n in tl.range(lo, hi, BLOCK_N):  # , loop_schedule=LOOP_SCHEDULE):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        with tl.async_task([LOAD_K]):
+        with AsyncTaskContext([LOAD_K]):
             if ENABLE_TMA:
                 k = desc_k.load(
                     [start_n.to(tl.int32) + (qvk_offset // stride_kn).to(tl.int32), 0]
                 )
             else:
                 k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
 
-        with tl.async_task([FIRST_MMA]):
+        with AsyncTaskContext([FIRST_MMA]):
             if ENABLE_TMA:  # feeds into gemm
                 k = tl.trans(k)
             qk0 = tl.dot(q0, k)
             qk1 = tl.dot(q1, k)
-        with tl.async_task([FIRST_SOFTMAX]):
+        with AsyncTaskContext([FIRST_SOFTMAX]):
             if STAGE == 2:
                 mask = offs_m0[:, None] >= (start_n + offs_n[None, :])
                 qk0 = qk0 * qk_scale + tl.where(mask, 0, -1.0e6)
@@ -413,22 +415,22 @@ def _attn_fwd_inner_ws_with_dp(
         # -- update m_i and l_i
         alpha0 = tl.math.exp2(m_i0 - m_ij0)
         l_i0 = l_i0 * alpha0 + l_ij0
-        with tl.async_task([FIRST_CORRECTION]):
+        with AsyncTaskContext([FIRST_CORRECTION]):
             if ALPHA_REMAT:
                 alpha0_re = tl.math.exp2(m_i0 - m_ij0)
                 # -- update output accumulator --
                 acc0 = acc0 * alpha0_re[:, None]
             else:
                 acc0 = acc0 * alpha0[:, None]
-        with tl.async_task([FIRST_SOFTMAX]):
+        with AsyncTaskContext([FIRST_SOFTMAX]):
             # update acc
             if fp8_v:
                 p0 = p0.to(tl.float8e5)
             else:
                 p0 = p0.to(tl.bfloat16)
             # update m_i and l_i
             m_i0 = m_ij0
-        with tl.async_task([LAST_SOFTMAX]):
+        with AsyncTaskContext([LAST_SOFTMAX]):
             if STAGE == 2:
                 mask = offs_m1[:, None] >= (start_n + offs_n[None, :])
                 qk1 = qk1 * qk_scale + tl.where(mask, 0, -1.0e6)
@@ -442,22 +444,22 @@ def _attn_fwd_inner_ws_with_dp(
         # -- update m_i and l_i
         alpha1 = tl.math.exp2(m_i1 - m_ij1)
         l_i1 = l_i1 * alpha1 + l_ij1
-        with tl.async_task([LAST_CORRECTION]):
+        with AsyncTaskContext([LAST_CORRECTION]):
             if ALPHA_REMAT:
                 alpha1_re = tl.math.exp2(m_i1 - m_ij1)
                 # -- update output accumulator --
                 acc1 = acc1 * alpha1_re[:, None]
             else:
                 acc1 = acc1 * alpha1[:, None]
-        with tl.async_task([LAST_SOFTMAX]):
+        with AsyncTaskContext([LAST_SOFTMAX]):
             # update acc
             if fp8_v:
                 p1 = p1.to(tl.float8e5)
             else:
                 p1 = p1.to(tl.bfloat16)
             # update m_i and l_i
             m_i1 = m_ij1
-        with tl.async_task([LOAD_V]):
+        with AsyncTaskContext([LOAD_V]):
             if ENABLE_TMA:
                 if fp8_v:
                     v = desc_v.load(
@@ -469,7 +471,7 @@ def _attn_fwd_inner_ws_with_dp(
                     )
             else:
                 v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
-        with tl.async_task([LAST_MMA]):
+        with AsyncTaskContext([LAST_MMA]):
             if fp8_v:
                 if ENABLE_TMA:
                     v = tl.trans(v)
@@ -941,7 +943,7 @@ def _attn_fwd_compute_ws(
     qk_scale = sm_scale
     qk_scale *= 1.44269504  # 1/log(2)
     # load q: it will stay in SRAM throughout
-    with tl.async_task([0]):
+    with AsyncTaskContext([0]):
         if ENABLE_TMA:
             q = desc_q.load(
                 [(qvk_offset // stride_qm + start_m * BLOCK_M).to(tl.int32), 0]
@@ -1011,7 +1013,7 @@ def _attn_fwd_compute_ws(
        LOOP_SCHEDULE,
    )
    # epilogue
-    with tl.async_task([1, 2]):
+    with AsyncTaskContext([1, 2]):
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
        m_ptrs = M + off_hz * N_CTX + offs_m
@@ -1104,14 +1106,14 @@ def _attn_fwd_compute_ws_with_dp(
     qk_scale *= 1.44269504  # 1/log(2)
     # load q: it will stay in SRAM throughout
     # q0 will be BLOCK_M, each kernel invocation will handle 2 * BLOCK_M
-    with tl.async_task([FIRST_LOADQ]):
+    with AsyncTaskContext([FIRST_LOADQ]):
         if ENABLE_TMA:
             q0 = desc_q.load(
                 [(qvk_offset // stride_qm + start_m * BLOCK_M).to(tl.int32), 0]
             )
         else:
             q0 = tl.load(Q_block_ptr)
-    with tl.async_task([LAST_LOADQ]):
+    with AsyncTaskContext([LAST_LOADQ]):
         if ENABLE_TMA:
             q1 = desc_q.load(
                 [
@@ -1214,7 +1216,7 @@ def _attn_fwd_compute_ws_with_dp(
        ALPHA_REMAT,
    )
    # epilogue
-    with tl.async_task([FIRST_SOFTMAX]):
+    with AsyncTaskContext([FIRST_SOFTMAX]):
        m_i0 += tl.math.log2(l_i0)
        acc0 = acc0 / l_i0[:, None]
        m_ptrs0 = M + off_hz * N_CTX + offs_m0
@@ -1226,7 +1228,7 @@ def _attn_fwd_compute_ws_with_dp(
            )
        else:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty))
-    with tl.async_task([LAST_SOFTMAX]):
+    with AsyncTaskContext([LAST_SOFTMAX]):
        m_i1 += tl.math.log2(l_i1)
        acc1 = acc1 / l_i1[:, None]
        m_ptrs1 = M + off_hz * N_CTX + offs_m1

tritonbench/utils/triton_utils.py

Lines changed: 22 additions & 0 deletions
@@ -1,5 +1,27 @@
 # utils to identify triton versions
 
+import triton.language as tl
+
+
+class AsyncTaskContext:
+    """Context manager that dispatches to tl.async_task if available, otherwise no-op."""
+
+    def __init__(self, task_ids):
+        self.task_ids = task_ids
+        self._has_async_task = hasattr(tl, "async_task")
+        self._context = None
+
+    def __enter__(self):
+        if self._has_async_task:
+            self._context = tl.async_task(self.task_ids)
+            return self._context.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._has_async_task and self._context is not None:
+            return self._context.__exit__(exc_type, exc_val, exc_tb)
+        return False
+
 
 def has_warp_spec():
     import triton.language as tl
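A usage sketch of the new helper; the task ids below are illustrative (the kernels pass constants such as LOAD_K, FIRST_MMA, LAST_SOFTMAX). Because __init__ resolves hasattr(tl, "async_task") once at construction, the same call site runs unchanged on Triton builds with or without warp specialization:

from tritonbench.utils.triton_utils import AsyncTaskContext

with AsyncTaskContext([0]):  # dispatches to tl.async_task if available
    pass  # producer work, e.g. the K/V loads in the attention kernels
with AsyncTaskContext([1, 2]):  # otherwise a no-op context manager
    pass  # consumer work, e.g. the MMA and softmax stages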
