
Commit d835d3d

y-sq authored and facebook-github-bot committed
fix flash attention benchmark failures
Reviewed By: xuzhao9
Differential Revision: D74487012
fbshipit-source-id: 4e5114f56b5115959a50b56b212a142adeef6b5d
1 parent 8d6e8b5 commit d835d3d

File tree

1 file changed (+5, -1)


tritonbench/operators/flash_attention/operator.py

Lines changed: 5 additions & 1 deletion
@@ -545,7 +545,11 @@ def get_ctx_vals():
         shapes = ctx_vals
         requires_grad = True
         for shape in shapes:
-            BATCH, H, N_CTX, N_CTX_KV, D_HEAD = shape
+            if len(shape) == 5:
+                BATCH, H, N_CTX, N_CTX_KV, D_HEAD = shape
+            else:
+                BATCH, H, N_CTX, D_HEAD = shape
+                N_CTX_KV = N_CTX
             q = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=self.dtype,
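The fix makes the benchmark accept both 4-element shapes (no separate KV context length) and 5-element shapes. A minimal standalone sketch of the same unpacking logic, with a hypothetical helper name (`unpack_shape` is not in the patched file):

```python
def unpack_shape(shape):
    """Unpack a benchmark shape tuple into (BATCH, H, N_CTX, N_CTX_KV, D_HEAD).

    Mirrors the patched logic: a 5-tuple carries an explicit KV context
    length; a 4-tuple omits it, so the KV length defaults to the query
    context length (self-attention).
    """
    if len(shape) == 5:
        batch, h, n_ctx, n_ctx_kv, d_head = shape
    else:
        batch, h, n_ctx, d_head = shape
        n_ctx_kv = n_ctx  # self-attention: KV length equals query length
    return batch, h, n_ctx, n_ctx_kv, d_head


# 4-element shape: KV context defaults to the query context length.
print(unpack_shape((4, 8, 128, 64)))        # (4, 8, 128, 128, 64)
# 5-element shape: explicit KV context length is preserved.
print(unpack_shape((4, 8, 128, 256, 64)))   # (4, 8, 128, 256, 64)
```

This is why the original one-line unpacking failed: any configuration that supplied a 4-tuple raised a `ValueError` ("not enough values to unpack") before the benchmark could allocate its tensors.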
