@@ -1644,7 +1644,7 @@ def keep2(conf):
     for s in [3]  # , 4, 7]
     for w in [4]  # , 8]
 ]
-configsBwd2 = [
+configsBwdWs = [
     (
         triton.Config(
             {
@@ -1655,8 +1655,8 @@ def keep2(conf):
             },
             num_stages=s,
             num_warps=w,
-            num_buffers_warp_spec=0,
-            num_consumer_groups=0,
+            num_buffers_warp_spec=2,
+            num_consumer_groups=2,
         )
         if has_warp_spec
         else triton.Config(
@@ -1670,8 +1670,8 @@ def keep2(conf):
             num_warps=w,
         )
     )
-    for BM in [32]  # 32, 64]  # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [64]  # 64, 128]
+    for BM in [64]  # 32, 64]  # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128]  # [64]  # 64, 128]
     for s in [3]  # , 4, 7]
     for w in [4]  # , 8]
 ]
@@ -1922,9 +1922,9 @@ def _attn_bwd(
     )


-@triton.autotune(list(filter(keep2, configsBwd2)), key=["N_CTX"])
+@triton.autotune(list(filter(keep2, configsBwdWs)), key=["N_CTX"])
 @triton.jit
-def _attn_bwd2(
+def _attn_bwd_ws(
     Q,
     K,
     V,
@@ -1978,7 +1978,7 @@ def _attn_bwd2(

 class _attention_opt(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, causal, sm_scale, baseVariant):  # , bwdVariant):
+    def forward(ctx, q, k, v, causal, sm_scale, baseVariant, bwdVariant):
         # shape constraints
         HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
         # when v is in float8_e5m2 it is transposed.
@@ -2366,7 +2366,7 @@ def grid_tma_persistent(META):
         ctx.sm_scale = sm_scale
         ctx.HEAD_DIM = HEAD_DIM_K
         ctx.causal = causal
-        # ctx.bwdVariant = bwdVariant
+        ctx.bwdVariant = bwdVariant
         # If we want to use different variants for bwd, save bwd mode here.
         return o

@@ -2385,8 +2385,6 @@ def backward(ctx, do):

         #NUM_WARPS, NUM_STAGES = 4, 5
         #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
-        NUM_WARPS, NUM_STAGES = 4, 3
-        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 128, 128, 64

         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
@@ -2409,7 +2407,7 @@ def backward(ctx, do):
         grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)
         #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
         print(q.stride(0), q.stride(1), q.stride(2), q.stride(3))
-        if True:  # ctx.bwdVariant == "base":
+        if ctx.bwdVariant == "base":
             _attn_bwd[grid](
                 q,
                 arg_k,
@@ -2436,8 +2434,8 @@ def backward(ctx, do):
                 #num_warps=NUM_WARPS, #
                 #num_stages=NUM_STAGES, #
             )
-        else:  # if ctx.bwdVariant == "base2":
-            _attn_bwd2[grid](
+        elif ctx.bwdVariant == "ws":
+            _attn_bwd_ws[grid](
                 q,
                 arg_k,
                 v,
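
A minimal usage sketch (not part of this commit) of how the backward-variant selection wired up above could be exercised. Only the forward argument order (q, k, v, causal, sm_scale, baseVariant, bwdVariant) and the "base"/"ws" dispatch in backward come from the diff; the tensor shapes, sm_scale value, the baseVariant value, and the direct call through _attention_opt.apply are illustrative assumptions:

    import torch

    # Illustrative (BATCH, N_HEAD, N_CTX, HEAD_DIM) tensors; shapes are assumptions.
    q = torch.randn(4, 16, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    # forward() now takes an explicit bwdVariant; backward() dispatches on ctx.bwdVariant:
    # "base" launches _attn_bwd, "ws" launches the warp-specialized _attn_bwd_ws.
    o = _attention_opt.apply(q, k, v, True, 0.5, "base", "ws")
    o.sum().backward()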