Skip to content

Commit e3f6db6

Browse files
yf225facebook-github-bot
authored andcommitted
Add accuracy check and fixes for fp8_attention Triton kernels (#276)
Summary: Stacked PRs: * #281 * __->__#276 --- --- --- ### Add accuracy check and fixes for fp8_attention Triton kernels Pull Request resolved: #276 Reviewed By: xuzhao9 Differential Revision: D78183079 Pulled By: yf225 fbshipit-source-id: 0a2a46a120163cc0bcfc03a6a879129c028b1f2e
1 parent 78b71eb commit e3f6db6

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

run.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,28 @@
1010
import sys
1111
from typing import List
1212

13+
# Apply async_task patch for Triton 3.4+ compatibility
14+
import triton.language as tl
15+
1316
from tritonbench.operator_loader import get_op_loader_bench_cls_by_name, is_loader_op
17+
18+
if not hasattr(tl, "async_task"):
19+
20+
class _AsyncTaskContext:
21+
"""A no-op context manager to replace tl.async_task"""
22+
23+
def __init__(self, task_ids):
24+
self.task_ids = task_ids
25+
26+
def __enter__(self):
27+
return self
28+
29+
def __exit__(self, exc_type, exc_val, exc_tb):
30+
return False
31+
32+
# Add async_task to triton.language
33+
tl.async_task = lambda task_ids: _AsyncTaskContext(task_ids)
34+
1435
from tritonbench.operators import load_opbench_by_name
1536
from tritonbench.operators_collection import list_operators_by_collection
1637
from tritonbench.utils.env_utils import is_fbcode

tritonbench/operators/fp8_attention/operator.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
7171
self.causal = True
7272
self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
73-
self.sm_scale = 1.3
73+
self.sm_scale = 1.0 / math.sqrt(float(self.D_HEAD))
7474

7575
if self.embedding_dim and self.H != self.embedding_dim // self.D_HEAD:
7676
raise ValueError(
@@ -119,7 +119,7 @@ def triton_preprocess(self, q, k, v):
119119
v,
120120
)
121121

122-
@register_benchmark()
122+
@register_benchmark(baseline=True)
123123
def triton_flash_v2(
124124
self,
125125
q: torch.Tensor,
@@ -129,7 +129,7 @@ def triton_flash_v2(
129129
triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
130130
# full fp8 will be enabled if type of q,k,v is fp8
131131
return lambda: triton_attention(
132-
triton_q, triton_k, triton_v, self.causal, self.sm_scale, "base"
132+
triton_q, triton_k, triton_v, self.causal, self.sm_scale, "base_opt"
133133
)
134134

135135
@register_benchmark()
@@ -189,12 +189,14 @@ def get_ctx_vals():
189189
device=self.device,
190190
requires_grad=self.requires_grad,
191191
)
192+
192193
k = torch.randn(
193194
(BATCH, H, N_CTX, D_HEAD),
194195
dtype=torch.float16,
195196
device=self.device,
196197
requires_grad=self.requires_grad,
197198
)
199+
198200
v = torch.randn(
199201
(BATCH, H, N_CTX, D_HEAD),
200202
dtype=torch.float16,
@@ -203,6 +205,42 @@ def get_ctx_vals():
203205
)
204206
yield (q, k, v)
205207

208+
def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
209+
"""
210+
Check accuracy of FP8 attention implementation against baseline.
211+
212+
FP8 operations have inherently lower precision, so we use relaxed tolerances.
213+
Based on empirical testing, FP8 can introduce differences up to ~2.0.
214+
"""
215+
try:
216+
output = fn()
217+
baseline_output = baseline_fn()
218+
219+
# Convert FP8 outputs to FP16 for comparison
220+
if output.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
221+
output = output.to(torch.float16)
222+
if baseline_output.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
223+
baseline_output = baseline_output.to(torch.float16)
224+
225+
# Validate outputs
226+
if torch.isnan(output).any() or torch.isinf(output).any():
227+
return False
228+
if torch.isnan(baseline_output).any() or torch.isinf(baseline_output).any():
229+
return False
230+
if output.shape != baseline_output.shape:
231+
return False
232+
233+
# FP8 attention uses relaxed tolerances due to:
234+
# 1. FP8 quantization of Q, K, V inputs
235+
# 2. FP8 quantization of attention weights (doesn't sum to exactly 1.0)
236+
# 3. Accumulation differences in FP8 GEMM operations
237+
result = torch.allclose(output, baseline_output, atol=2.0, rtol=0.2)
238+
239+
return result
240+
241+
except Exception:
242+
return False
243+
206244
@register_metric()
207245
def flops(
208246
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics

0 commit comments

Comments
 (0)