@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
458
458
num_warps = w ,
459
459
)
460
460
)
461
- for BM in [128 ] # 64, 128]
462
- for BN in [128 ] # 64, 128]
463
- for s in [3 ] # 3, 4, 7]
464
- for w in [8 ] # 4, 8]
461
+ for BM in [128 ] # 64, 128]
462
+ for BN in [128 ] # 64, 128]
463
+ for s in [3 ] # 3, 4, 7]
464
+ for w in [8 ] # 4, 8]
465
465
]
466
466
# TMA, WS, and CompPipe
467
467
configsTmaWS = [
@@ -1637,11 +1637,11 @@ def _attn_bwd_dkdv_ws(
1637
1637
# Load m before computing qk to reduce pipeline stall.
1638
1638
offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1639
1639
m = tl .load (M + offs_m )
1640
- #with tl.async_task([0]):
1640
+ # with tl.async_task([0]):
1641
1641
# do = tl.load(do_ptrs)
1642
1642
with tl .async_task ([1 , 2 ]):
1643
1643
qkT = tl .dot (k , qT )
1644
- #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1644
+ # dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1645
1645
pT = tl .math .exp2 (qkT - m [None , :])
1646
1646
# Autoregressive masking.
1647
1647
if MASK :
@@ -1746,10 +1746,10 @@ def keep2(conf):
1746
1746
"BLOCK_M2" : BN ,
1747
1747
"BLOCK_N2" : BM ,
1748
1748
},
1749
- num_stages = s , # 0 or s,
1749
+ num_stages = s , # 0 or s,
1750
1750
num_warps = w ,
1751
- num_buffers_warp_spec = 0 , # 0 or 2,
1752
- num_consumer_groups = 0 , # 0 or 1,
1751
+ num_buffers_warp_spec = 0 , # 0 or 2,
1752
+ num_consumer_groups = 0 , # 0 or 1,
1753
1753
)
1754
1754
if has_warp_spec
1755
1755
else triton .Config (
@@ -1763,10 +1763,10 @@ def keep2(conf):
1763
1763
num_warps = w ,
1764
1764
)
1765
1765
)
1766
- for BM in [64 ] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
1767
- for BN in [128 ] # 64, 128]
1768
- for s in [3 ]# , 4, 7]
1769
- for w in [4 ]# , 8]
1766
+ for BM in [64 ] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
1767
+ for BN in [128 ] # 64, 128]
1768
+ for s in [3 ] # , 4, 7]
1769
+ for w in [4 ] # , 8]
1770
1770
]
1771
1771
configsBwdWs = [
1772
1772
(
@@ -1794,10 +1794,10 @@ def keep2(conf):
1794
1794
num_warps = w ,
1795
1795
)
1796
1796
)
1797
- for BM in [64 ] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
1798
- for BN in [128 ] # [64] #64, 128]
1799
- for s in [3 ]# , 4, 7]
1800
- for w in [4 ]# , 8]
1797
+ for BM in [64 ] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
1798
+ for BN in [128 ] # [64] #64, 128]
1799
+ for s in [3 ] # , 4, 7]
1800
+ for w in [4 ] # , 8]
1801
1801
]
1802
1802
1803
1803
@@ -2698,8 +2698,8 @@ def backward(ctx, do):
2698
2698
BATCH , N_HEAD , N_CTX = q .shape [:3 ]
2699
2699
PRE_BLOCK = 128
2700
2700
2701
- #NUM_WARPS, NUM_STAGES = 4, 5
2702
- #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
2701
+ # NUM_WARPS, NUM_STAGES = 4, 5
2702
+ # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
2703
2703
2704
2704
BLK_SLICE_FACTOR = 2
2705
2705
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
@@ -2720,8 +2720,7 @@ def backward(ctx, do):
2720
2720
HEAD_DIM = ctx .HEAD_DIM , #
2721
2721
)
2722
2722
grid = lambda args : (N_CTX // args ["BLOCK_N1" ], 1 , BATCH * N_HEAD )
2723
- #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
2724
- print (q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ))
2723
+ # grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
2725
2724
if ctx .bwdVariant == "base" :
2726
2725
_attn_bwd [grid ](
2727
2726
q ,
@@ -2740,14 +2739,14 @@ def backward(ctx, do):
2740
2739
q .stride (3 ), #
2741
2740
N_HEAD ,
2742
2741
N_CTX , #
2743
- #BLOCK_M1=BLOCK_M1,
2744
- #BLOCK_N1=BLOCK_N1, #
2745
- #BLOCK_M2=BLOCK_M2,
2746
- #BLOCK_N2=BLOCK_N2, #
2742
+ # BLOCK_M1=BLOCK_M1,
2743
+ # BLOCK_N1=BLOCK_N1, #
2744
+ # BLOCK_M2=BLOCK_M2,
2745
+ # BLOCK_N2=BLOCK_N2, #
2747
2746
BLK_SLICE_FACTOR = BLK_SLICE_FACTOR , #
2748
2747
HEAD_DIM = ctx .HEAD_DIM , #
2749
- #num_warps=NUM_WARPS, #
2750
- #num_stages=NUM_STAGES, #
2748
+ # num_warps=NUM_WARPS, #
2749
+ # num_stages=NUM_STAGES, #
2751
2750
)
2752
2751
elif ctx .bwdVariant == "ws" :
2753
2752
_attn_bwd_ws [grid ](
@@ -2767,14 +2766,14 @@ def backward(ctx, do):
2767
2766
q .stride (3 ), #
2768
2767
N_HEAD ,
2769
2768
N_CTX , #
2770
- #BLOCK_M1=BLOCK_M1,
2771
- #BLOCK_N1=BLOCK_N1, #
2772
- #BLOCK_M2=BLOCK_M2,
2773
- #BLOCK_N2=BLOCK_N2, #
2769
+ # BLOCK_M1=BLOCK_M1,
2770
+ # BLOCK_N1=BLOCK_N1, #
2771
+ # BLOCK_M2=BLOCK_M2,
2772
+ # BLOCK_N2=BLOCK_N2, #
2774
2773
BLK_SLICE_FACTOR = BLK_SLICE_FACTOR , #
2775
2774
HEAD_DIM = ctx .HEAD_DIM , #
2776
- #num_warps=NUM_WARPS, #
2777
- #num_stages=NUM_STAGES, #
2775
+ # num_warps=NUM_WARPS, #
2776
+ # num_stages=NUM_STAGES, #
2778
2777
)
2779
2778
2780
2779
return dq , dk , dv , None , None , None , None
0 commit comments