
Commit e652eb0

fix variants
1 parent 152d476 commit e652eb0

File tree

2 files changed: +15, -17 lines

tritonbench/kernels/triton_fused_attention.py

Lines changed: 12 additions & 14 deletions
@@ -1644,7 +1644,7 @@ def keep2(conf):
     for s in [3]#, 4, 7]
     for w in [4]#, 8]
 ]
-configsBwd2 = [
+configsBwdWs = [
     (
         triton.Config(
             {
@@ -1655,8 +1655,8 @@ def keep2(conf):
             },
             num_stages=s,
             num_warps=w,
-            num_buffers_warp_spec=0,
-            num_consumer_groups=0,
+            num_buffers_warp_spec=2,
+            num_consumer_groups=2,
         )
         if has_warp_spec
         else triton.Config(
@@ -1670,8 +1670,8 @@ def keep2(conf):
             num_warps=w,
         )
     )
-    for BM in [32] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [64] #64, 128]
+    for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128] #[64] #64, 128]
     for s in [3]#, 4, 7]
     for w in [4]#, 8]
 ]
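For context, the configsBwdWs list above is a comprehension that yields warp-specialized triton.Config objects on builds that support warp specialization, and plain configs otherwise; this commit renames the list, turns the warp-spec fields on (2 buffers, 2 consumer groups), and bumps the block sizes to BM=64/BN=128. A minimal standalone sketch of that pattern follows; the config dict keys and the has_warp_spec detection are assumptions, since the diff elides them:

import triton

# Assumption: the real file detects warp-spec support at import time; stock
# Triton builds reject the warp-spec keyword arguments, so default to False.
has_warp_spec = False

# Block-size keys below are assumed from the BM/BN comments in the diff.
configs_bwd_ws = [
    (
        triton.Config(
            {"BLOCK_M1": BM, "BLOCK_N1": BN, "BLOCK_M2": BN, "BLOCK_N2": BM},
            num_stages=s,
            num_warps=w,
            num_buffers_warp_spec=2,  # extra fields exist only on warp-spec builds
            num_consumer_groups=2,
        )
        if has_warp_spec
        # Only the chosen branch of the conditional is evaluated, so the
        # fallback below keeps this runnable on stock Triton.
        else triton.Config(
            {"BLOCK_M1": BM, "BLOCK_N1": BN, "BLOCK_M2": BN, "BLOCK_N2": BM},
            num_stages=s,
            num_warps=w,
        )
    )
    for BM in [64]
    for BN in [128]
    for s in [3]
    for w in [4]
]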
@@ -1922,9 +1922,9 @@ def _attn_bwd(
     )


-@triton.autotune(list(filter(keep2, configsBwd2)), key=["N_CTX"])
+@triton.autotune(list(filter(keep2, configsBwdWs)), key=["N_CTX"])
 @triton.jit
-def _attn_bwd2(
+def _attn_bwd_ws(
     Q,
     K,
     V,
@@ -1978,7 +1978,7 @@ def _attn_bwd2(

 class _attention_opt(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, causal, sm_scale, baseVariant): #, bwdVariant):
+    def forward(ctx, q, k, v, causal, sm_scale, baseVariant, bwdVariant):
         # shape constraints
         HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
         # when v is in float8_e5m2 it is transposed.
@@ -2366,7 +2366,7 @@ def grid_tma_persistent(META):
         ctx.sm_scale = sm_scale
         ctx.HEAD_DIM = HEAD_DIM_K
         ctx.causal = causal
-        #ctx.bwdVariant = bwdVariant
+        ctx.bwdVariant = bwdVariant
         # If we want to use different variants for bwd, save bwd mode here.
         return o

@@ -2385,8 +2385,6 @@ def backward(ctx, do):

         #NUM_WARPS, NUM_STAGES = 4, 5
         #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
-        NUM_WARPS, NUM_STAGES = 4, 3
-        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 128, 128, 64

         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
@@ -2409,7 +2407,7 @@ def backward(ctx, do):
         grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)
         #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
         print(q.stride(0), q.stride(1), q.stride(2), q.stride(3))
-        if True: #ctx.bwdVariant == "base":
+        if ctx.bwdVariant == "base":
             _attn_bwd[grid](
                 q,
                 arg_k,
@@ -2436,8 +2434,8 @@ def backward(ctx, do):
                 #num_warps=NUM_WARPS, #
                 #num_stages=NUM_STAGES, #
             )
-        else: #if ctx.bwdVariant == "base2":
-            _attn_bwd2[grid](
+        elif ctx.bwdVariant == "ws":
+            _attn_bwd_ws[grid](
                 q,
                 arg_k,
                 v,
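Taken together, the kernel-side edits thread the new bwdVariant flag through the autograd function: forward stores it on ctx, and backward dispatches to either _attn_bwd ("base") or the warp-specialized _attn_bwd_ws ("ws"). A self-contained toy sketch of that save-and-dispatch pattern, with toy math and hypothetical names standing in for the real Triton kernel launches:

import torch

def _bwd_base(q, do):
    # stand-in for the _attn_bwd kernel launch
    return 2 * q * do

def _bwd_ws(q, do):
    # stand-in for the warp-specialized _attn_bwd_ws launch;
    # same math here, only the scheduling differs in the real kernels
    return 2 * q * do

class _ToyVariantFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, bwd_variant):
        ctx.save_for_backward(q)
        ctx.bwd_variant = bwd_variant  # remember which backward variant to run
        return q * q

    @staticmethod
    def backward(ctx, do):
        (q,) = ctx.saved_tensors
        if ctx.bwd_variant == "base":
            dq = _bwd_base(q, do)
        elif ctx.bwd_variant == "ws":
            dq = _bwd_ws(q, do)
        else:
            raise ValueError(f"unknown bwd variant: {ctx.bwd_variant}")
        return dq, None  # no gradient for the variant flag

# usage: the flag picks the backward path without changing the result
q = torch.randn(8, requires_grad=True)
_ToyVariantFn.apply(q, "ws").sum().backward()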

tritonbench/operators/flash_attention/operator.py

Lines changed: 3 additions & 3 deletions
@@ -252,19 +252,19 @@ def triton_tutorial_flash_v2(
     ) -> Callable:
         # base: do not enable TMA/WarpSpec/CompPipe
         return lambda: triton_tutorial_FA2_opt(
-            q, k, v, self.causal, self.sm_scale, "base"#, "base"
+            q, k, v, self.causal, self.sm_scale, "base", "base"
         )

     @register_benchmark()
-    def triton_tutorial_flash_v2_bwd2(
+    def triton_tutorial_flash_v2_bwd_ws(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
         # base: do not enable TMA/WarpSpec/CompPipe
         return lambda: triton_tutorial_FA2_opt(
-            q, k, v, self.causal, self.sm_scale, "base", "base2"
+            q, k, v, self.causal, self.sm_scale, "base", "ws"
         )

     @register_benchmark(enabled=HAS_CUDA_124)
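With both benchmarks registered, callers select the backward kernel through the last positional argument of triton_tutorial_FA2_opt. A hypothetical direct invocation; the import path, tensor shapes, and scale value are assumptions, and it needs a CUDA device:

import torch
from tritonbench.kernels.triton_fused_attention import (  # import path assumed
    attention_opt as triton_tutorial_FA2_opt,
)

# (BATCH, N_HEAD, N_CTX, HEAD_DIM) chosen so N_CTX divides the block sizes above.
q, k, v = (
    torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

# "base" forward variant; the final "ws" argument routes backward to _attn_bwd_ws.
o = triton_tutorial_FA2_opt(q, k, v, True, 0.5, "base", "ws")
o.backward(torch.ones_like(o))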
