Skip to content

Commit 8d2ca2e

Browse files
committed
use [1,2]
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 50db08c commit 8d2ca2e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,7 +1640,7 @@ def _attn_bwd_dkdv_ws(
1640 1640   m = tl.load(M + offs_m)
1641 1641   #with tl.async_task([0]):
1642 1642   # do = tl.load(do_ptrs)
1643      - with tl.async_task([1]):
     1643 + with tl.async_task([1, 2]):
1644 1644   qkT = tl.dot(k, qT)
1645 1645   #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1646 1646   pT = tl.math.exp2(qkT - m[None, :])
@@ -1651,7 +1651,7 @@ def _attn_bwd_dkdv_ws(
1651 1651   with tl.async_task([0]):
1652 1652   do = tl.load(do_ptrs)
1653 1653   # Compute dV.
1654      - with tl.async_task([1]):
     1654 + with tl.async_task([1, 2]):
1655 1655   ppT = pT
1656 1656   ppT = ppT.to(tl.bfloat16)
1657 1657   dv += tl.dot(ppT, do)
@@ -1708,7 +1708,7 @@ def _attn_bwd_dq_ws(
1708 1708   with tl.async_task([0]):
1709 1709   kT = tl.load(kT_ptrs)
1710 1710   vT = tl.load(vT_ptrs)
1711      - with tl.async_task([1]):
     1711 + with tl.async_task([1, 2]):
1712 1712   qk = tl.dot(q, kT)
1713 1713   p = tl.math.exp2(qk - m)
1714 1714   # Autoregressive masking.

0 commit comments

Comments
 (0)