
Commit 3d10ee5

minor fix for original variant
1 parent 0908180 commit 3d10ee5

File tree

1 file changed: +1 −2 lines changed


tritonbench/kernels/triton_fused_attention.py

Lines changed: 1 addition & 2 deletions
@@ -1513,7 +1513,6 @@ def _attn_bwd_dkdv(
     offs_m = curr_m + tl.arange(0, BLOCK_M1)
     m = tl.load(M + offs_m)
     qkT = tl.dot(k, qT)
-    #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
     pT = tl.math.exp2(qkT - m[None, :])
     # Autoregressive masking.
     if MASK:
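The deleted line above was already commented out, so this hunk only drops dead code. For context, here is a minimal NumPy sketch (mine, not from the repo) of the base-2 softmax identity behind pT = tl.math.exp2(qkT - m[None, :]): assuming, as in the standard Triton attention tutorial, that the scores are pre-scaled by log2(e) and M stores the per-row log2-sum-exp, exp2(score - m) is exactly the softmax probability, so the backward pass needs no separate normalization step. The helper name below is hypothetical.

import numpy as np

# Sketch only: the base-2 softmax trick used by exp2-based attention kernels.
# scores are assumed pre-scaled by log2(e); m is the per-row log2-sum-exp.
# (The real kernel also folds the running row max into m for stability.)
def softmax_via_exp2(scores):
    m = np.log2(np.exp2(scores).sum(axis=1, keepdims=True))
    return np.exp2(scores - m)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))                         # raw logits
ref = np.exp(x) / np.exp(x).sum(axis=1, keepdims=True)  # standard softmax
assert np.allclose(softmax_via_exp2(x * np.log2(np.e)), ref)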
@@ -1930,8 +1929,8 @@ def _attn_bwd_compute(
     offs_m = start_m + tl.arange(0, BLOCK_M2)

     q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
-    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
     dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
+    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

     m = tl.load(M + offs_m)
     m = m[:, None]
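The second hunk is pure code motion: the do load moves below the zero-initialization of the dq accumulator. Nothing between the two positions reads do, so the computed values are unchanged; only the ordering of the load differs.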
