File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -1513,7 +1513,6 @@ def _attn_bwd_dkdv(
1513
1513
offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1514
1514
m = tl .load (M + offs_m )
1515
1515
qkT = tl .dot (k , qT )
1516
- #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1517
1516
pT = tl .math .exp2 (qkT - m [None , :])
1518
1517
# Autoregressive masking.
1519
1518
if MASK :
@@ -1930,8 +1929,8 @@ def _attn_bwd_compute(
1930
1929
offs_m = start_m + tl .arange (0 , BLOCK_M2 )
1931
1930
1932
1931
q = tl .load (Q + offs_m [:, None ] * stride_tok + offs_k [None , :] * stride_d )
1933
- do = tl .load (DO + offs_m [:, None ] * stride_tok + offs_k [None , :] * stride_d )
1934
1932
dq = tl .zeros ([BLOCK_M2 , HEAD_DIM ], dtype = tl .float32 )
1933
+ do = tl .load (DO + offs_m [:, None ] * stride_tok + offs_k [None , :] * stride_d )
1935
1934
1936
1935
m = tl .load (M + offs_m )
1937
1936
m = m [:, None ]
You can’t perform that action at this time.
0 commit comments