
Commit 78b71eb

yf225 authored and facebook-github-bot committed
Fix accuracy check for flash_attention kernels (#280)
Summary:

Stacked PRs:
* #276
* __->__ #280

### Fix accuracy check for flash_attention kernels

Pull Request resolved: #280

Reviewed By: oulgen

Differential Revision: D78222794

Pulled By: yf225

fbshipit-source-id: 4c104b1defc84a5ffecbe274ea00c590f9463362
1 parent 5a03f18 commit 78b71eb

File tree

2 files changed: +35, -5 lines changed


tritonbench/kernels/triton_fused_attention.py

Lines changed: 8 additions & 2 deletions
@@ -1432,7 +1432,7 @@ def _attn_fwd_base_opt(
     tl.assume(H >= 0)
 
     tl.static_assert(BLOCK_N <= HEAD_DIM)
-    pid = tl.program_id(0)
+    start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
 
     # Both base and opt use the same compute function
@@ -1464,7 +1464,7 @@ def _attn_fwd_base_opt(
         stride_om,
         stride_on,
         off_hz,
-        pid,
+        start_m,
         Z,
         H,
         N_CTX,
@@ -1753,6 +1753,12 @@ def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out, #
 
     tile_idx = prog_id
 
+    # Initialize descriptors as None
+    desc_q = None
+    desc_k = None
+    desc_v = None
+    desc_o = None
+
     if ENABLE_TMA:
         desc_k = tl.make_tensor_descriptor(
             K,
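
The descriptor change above follows the usual pattern of binding names before a conditional branch so that later code can reference them unconditionally. A minimal plain-Python sketch of that pattern (not Triton, and not part of this commit; `make_descriptor` is a hypothetical stand-in for `tl.make_tensor_descriptor`):

```python
def make_descriptor(tensor):
    # Hypothetical placeholder for a TMA tensor descriptor.
    return {"base": tensor}


def attention_entry(K, ENABLE_TMA: bool):
    # Initialize descriptors as None so they are always bound,
    # even when the TMA branch is skipped.
    desc_q = desc_k = desc_v = desc_o = None

    if ENABLE_TMA:
        desc_k = make_descriptor(K)

    # Later code can use desc_* regardless of ENABLE_TMA.
    return desc_q, desc_k, desc_v, desc_o
```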

tritonbench/operators/flash_attention/operator.py

Lines changed: 27 additions & 3 deletions
@@ -201,17 +201,19 @@ def __init__(
         self.pt2_sdpa = args.pt2_sdpa
         self.additional_inputs = args.additional_inputs
         self.ragged_shapes = args.ragged_shapes
-        self.sm_scale = 1.3
+        # Use standard scale factor: 1/sqrt(head_dim)
+        self.sm_scale = 1.0 / (self.D_HEAD**0.5)
 
-    @register_benchmark()
+    @register_benchmark(baseline=True)
     def aten(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
         def _inner():
-            M = torch.tril(torch.ones((self.N_CTX, self.N_CTX), device=self.device))
+            seq_len = q.shape[2]
+            M = torch.tril(torch.ones((seq_len, seq_len), device=self.device))
             p = torch.matmul(q, k.transpose(2, 3)) * self.sm_scale
             if self.causal:
                 p[:, :, M == 0] = float("-inf")
@@ -524,6 +526,28 @@ def causal_mask(b, h, q_idx, kv_idx):
 
         return lambda: flex_attention(q, k, v, block_mask=block_mask)
 
+    def accuracy(self, fn, baseline_fn):
+        """Override accuracy to use relaxed tolerance for bfloat16."""
+        output = fn()
+        baseline_output = baseline_fn()
+
+        # Check for NaN values
+        if torch.isnan(output).any():
+            return False
+
+        try:
+            # Use relaxed tolerance for bfloat16/float16
+            # Using atol=2e-2 and rtol=1e-2 to provide some margin
+            if output.dtype in [torch.bfloat16, torch.float16]:
+                torch.testing.assert_close(
+                    output, baseline_output, rtol=1e-2, atol=2e-2
+                )
+            else:
+                torch.testing.assert_close(output, baseline_output)
+            return True
+        except Exception:
+            return False
+
     @register_metric(x_only=True)
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
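
For context, a minimal standalone sketch (not tritonbench code, not part of this commit) of what the corrected check exercises: a reference attention using the standard `1/sqrt(head_dim)` scale and a causal mask sized from `q.shape[2]`, compared against `torch.nn.functional.scaled_dot_product_attention` under the relaxed bfloat16 tolerances used above. The `reference_attention` helper and the shapes are illustrative only:

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v, causal=True):
    # Standard scale factor: 1/sqrt(head_dim), matching the sm_scale fix above.
    head_dim = q.shape[-1]
    seq_len = q.shape[2]
    sm_scale = 1.0 / math.sqrt(head_dim)
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        # Causal mask sized by the actual sequence length (q.shape[2]).
        mask = torch.tril(torch.ones((seq_len, seq_len), device=q.device))
        p = p.masked_fill(mask == 0, float("-inf"))
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 2, 128, 64, dtype=torch.bfloat16) for _ in range(3))
    out = reference_attention(q, k, v)
    expected = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Relaxed tolerances for bfloat16, mirroring the accuracy override above.
    torch.testing.assert_close(out, expected, rtol=1e-2, atol=2e-2)
    print("accuracy check passed")
```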
