
Commit 757db43

lint
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 79c804f commit 757db43

2 files changed: +32 -35 lines


tritonbench/kernels/triton_fused_attention.py

Lines changed: 32 additions & 33 deletions
@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128] #64, 128]
-    for BN in [128] #64, 128]
-    for s in [3] #3, 4, 7]
-    for w in [8] #4, 8]
+    for BM in [128] # 64, 128]
+    for BN in [128] # 64, 128]
+    for s in [3] # 3, 4, 7]
+    for w in [8] # 4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
@@ -1637,11 +1637,11 @@ def _attn_bwd_dkdv_ws(
         # Load m before computing qk to reduce pipeline stall.
         offs_m = curr_m + tl.arange(0, BLOCK_M1)
         m = tl.load(M + offs_m)
-        #with tl.async_task([0]):
+        # with tl.async_task([0]):
         # do = tl.load(do_ptrs)
         with tl.async_task([1, 2]):
             qkT = tl.dot(k, qT)
-            #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
+            # dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
             pT = tl.math.exp2(qkT - m[None, :])
             # Autoregressive masking.
             if MASK:
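A note on the unchanged `pT = tl.math.exp2(qkT - m[None, :])` line this hunk surrounds: the backward kernel reloads the per-row statistic `M` saved by the forward pass (the row maximum folded with the log of the softmax normalizer, in the usual FlashAttention formulation), so a single base-2 exponential yields already-normalized probabilities. A small NumPy sketch of that identity, with made-up shapes and outside Triton:

import numpy as np

# Sketch of the base-2 softmax identity used by the kernel (assumed shapes).
# The saved statistic M = rowmax + log2(row normalizer), so exp2(qk - M)
# equals exp2(qk - rowmax) / normalizer, i.e. the normalized probability.
rng = np.random.default_rng(0)
qk = rng.standard_normal((4, 8))          # pre-scaled attention scores
m = qk.max(axis=1)                        # per-row running max
l = np.exp2(qk - m[:, None]).sum(axis=1)  # per-row normalizer in base 2
M = m + np.log2(l)                        # statistic saved by the forward pass
p = np.exp2(qk - M[:, None])              # normalized probabilities
assert np.allclose(p.sum(axis=1), 1.0)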
@@ -1746,10 +1746,10 @@ def keep2(conf)
                 "BLOCK_M2": BN,
                 "BLOCK_N2": BM,
             },
-            num_stages=s, #0 or s,
+            num_stages=s, # 0 or s,
             num_warps=w,
-            num_buffers_warp_spec=0, #0 or 2,
-            num_consumer_groups=0, #0 or 1,
+            num_buffers_warp_spec=0, # 0 or 2,
+            num_consumer_groups=0, # 0 or 1,
         )
         if has_warp_spec
         else triton.Config(
@@ -1763,10 +1763,10 @@ def keep2(conf)
             num_warps=w,
         )
     )
-    for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [128] #64, 128]
-    for s in [3]#, 4, 7]
-    for w in [4]#, 8]
+    for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128] # 64, 128]
+    for s in [3] # , 4, 7]
+    for w in [4] # , 8]
 ]
 configsBwdWs = [
     (
@@ -1794,10 +1794,10 @@ def keep2(conf)
             num_warps=w,
         )
     )
-    for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [128] #[64] #64, 128]
-    for s in [3]#, 4, 7]
-    for w in [4]#, 8]
+    for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128] # [64] #64, 128]
+    for s in [3] # , 4, 7]
+    for w in [4] # , 8]
 ]
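The comment-spacing hunks above all land inside the same construct: list comprehensions that sweep block sizes (BM, BN), num_stages (s), and num_warps (w) into triton.Config objects for autotuning, with the wider sweeps left disabled in trailing comments and, in this file, each element wrapped in an `if has_warp_spec ... else triton.Config(...)` conditional. A minimal standalone sketch of the pattern (the `configs` name, the meta-parameter keys, and the `@triton.autotune` usage are illustrative, not the exact ones in triton_fused_attention.py):

import triton
import triton.language as tl

# Build autotune candidates by sweeping a few meta-parameters; the commented
# values mirror the diff's style of keeping larger sweeps disabled inline.
configs = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN},
        num_stages=s,
        num_warps=w,
    )
    for BM in [128]  # 64, 128
    for BN in [128]  # 64, 128
    for s in [3]  # 3, 4, 7
    for w in [8]  # 4, 8
]


@triton.autotune(configs=configs, key=["N_CTX"])
@triton.jit
def _toy_kernel(X, N_CTX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Placeholder body; a real kernel would tile its work by BLOCK_M x BLOCK_N.
    pass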
@@ -2698,8 +2698,8 @@ def backward(ctx, do):
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         PRE_BLOCK = 128

-        #NUM_WARPS, NUM_STAGES = 4, 5
-        #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
+        # NUM_WARPS, NUM_STAGES = 4, 5
+        # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32

         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
@@ -2720,8 +2720,7 @@ def backward(ctx, do):
             HEAD_DIM=ctx.HEAD_DIM, #
         )
         grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)
-        #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
-        print(q.stride(0), q.stride(1), q.stride(2), q.stride(3))
+        # grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
         if ctx.bwdVariant == "base":
             _attn_bwd[grid](
                 q,
@@ -2740,14 +2739,14 @@ def backward(ctx, do):
                 q.stride(3), #
                 N_HEAD,
                 N_CTX, #
-                #BLOCK_M1=BLOCK_M1,
-                #BLOCK_N1=BLOCK_N1, #
-                #BLOCK_M2=BLOCK_M2,
-                #BLOCK_N2=BLOCK_N2, #
+                # BLOCK_M1=BLOCK_M1,
+                # BLOCK_N1=BLOCK_N1, #
+                # BLOCK_M2=BLOCK_M2,
+                # BLOCK_N2=BLOCK_N2, #
                 BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
                 HEAD_DIM=ctx.HEAD_DIM, #
-                #num_warps=NUM_WARPS, #
-                #num_stages=NUM_STAGES, #
+                # num_warps=NUM_WARPS, #
+                # num_stages=NUM_STAGES, #
             )
         elif ctx.bwdVariant == "ws":
             _attn_bwd_ws[grid](
@@ -2767,14 +2766,14 @@ def backward(ctx, do):
                 q.stride(3), #
                 N_HEAD,
                 N_CTX, #
-                #BLOCK_M1=BLOCK_M1,
-                #BLOCK_N1=BLOCK_N1, #
-                #BLOCK_M2=BLOCK_M2,
-                #BLOCK_N2=BLOCK_N2, #
+                # BLOCK_M1=BLOCK_M1,
+                # BLOCK_N1=BLOCK_N1, #
+                # BLOCK_M2=BLOCK_M2,
+                # BLOCK_N2=BLOCK_N2, #
                 BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
                 HEAD_DIM=ctx.HEAD_DIM, #
-                #num_warps=NUM_WARPS, #
-                #num_stages=NUM_STAGES, #
+                # num_warps=NUM_WARPS, #
+                # num_stages=NUM_STAGES, #
             )

         return dq, dk, dv, None, None, None, None
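One more piece of context for the backward hunks: the retained `grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)` works because a Triton launch `kernel[grid](...)` accepts either a tuple or a callable that is given the resolved meta-parameters, so the launch geometry can depend on a block size chosen by the autotuner; the deleted `print(q.stride(...))` line was leftover debug output. A toy, self-contained example of the callable-grid idiom (the kernel, names, and sizes are made up and need a GPU to run; this is not the `_attn_bwd` signature):

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(x_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program doubles one BLOCK-sized slice of x in place.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2.0, mask=mask)


x = torch.randn(4096, device="cuda")
# The grid callable receives the final meta-parameters, so the launch
# geometry can follow BLOCK even when BLOCK comes from an autotuner.
grid = lambda args: (triton.cdiv(x.numel(), args["BLOCK"]),)
_scale_kernel[grid](x, x.numel(), BLOCK=256)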

tritonbench/operators/flash_attention/operator.py

Lines changed: 0 additions & 2 deletions
@@ -470,8 +470,6 @@ def tflops(
         BATCH, H, N_CTX, D_HEAD = q.shape
         flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
         tflops = 2 * flops_per_matmul
-        print("causal, mode: ", self.causal, self.mode)
-        print("fn_name: ", fn_name, metrics.latency)
         if self.causal:
             tflops *= 0.5
         if self.mode == BenchmarkMode.BWD:
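The two deleted `print` calls were debug output dropped from the middle of the TFLOPS bookkeeping visible in this hunk: two matmuls of 2 * BATCH * H * N_CTX^2 * D_HEAD flops each, halved under a causal mask, with a further backward-mode adjustment that the hunk truncates before reaching. A standalone sketch of that accounting (the function name, parameter list, and the 2.5x backward factor are assumptions for illustration, not taken from the operator):

# Rough flop accounting in the style of the hunk above; the backward
# factor (2.0 for dQ/dK/dV plus 0.5 for recompute) is an assumption here.
def attention_tflops_estimate(
    batch: int, heads: int, n_ctx: int, d_head: int,
    causal: bool, backward: bool, latency_ms: float,
) -> float:
    flops_per_matmul = 2.0 * batch * heads * n_ctx * n_ctx * d_head
    flops = 2 * flops_per_matmul  # QK^T and PV matmuls in the forward pass
    if causal:
        flops *= 0.5  # only the lower triangle of scores is computed
    if backward:
        flops *= 2.5  # assumed backward multiplier, see note above
    return flops * 1e-12 / (latency_ms * 1e-3)  # TFLOP/s given latency in ms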
