File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -1640,7 +1640,7 @@ def _attn_bwd_dkdv_ws(
1640
1640
m = tl .load (M + offs_m )
1641
1641
#with tl.async_task([0]):
1642
1642
# do = tl.load(do_ptrs)
1643
- with tl .async_task ([1 ]):
1643
+ with tl .async_task ([1 , 2 ]):
1644
1644
qkT = tl .dot (k , qT )
1645
1645
#dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1646
1646
pT = tl .math .exp2 (qkT - m [None , :])
@@ -1651,7 +1651,7 @@ def _attn_bwd_dkdv_ws(
1651
1651
with tl .async_task ([0 ]):
1652
1652
do = tl .load (do_ptrs )
1653
1653
# Compute dV.
1654
- with tl .async_task ([1 ]):
1654
+ with tl .async_task ([1 , 2 ]):
1655
1655
ppT = pT
1656
1656
ppT = ppT .to (tl .bfloat16 )
1657
1657
dv += tl .dot (ppT , do )
@@ -1708,7 +1708,7 @@ def _attn_bwd_dq_ws(
1708
1708
with tl .async_task ([0 ]):
1709
1709
kT = tl .load (kT_ptrs )
1710
1710
vT = tl .load (vT_ptrs )
1711
- with tl .async_task ([1 ]):
1711
+ with tl .async_task ([1 , 2 ]):
1712
1712
qk = tl .dot (q , kT )
1713
1713
p = tl .math .exp2 (qk - m )
1714
1714
# Autoregressive masking.
You can’t perform that action at this time.
0 commit comments