@@ -70,7 +70,7 @@ def __init__(
70
70
if self .mode == BenchmarkMode .BWD or self .mode == BenchmarkMode .FWD_BWD :
71
71
self .causal = True
72
72
self .requires_grad = not self .tb_args .mode == "fwd_no_grad"
73
- self .sm_scale = 1.3
73
+ self .sm_scale = 1.0 / math . sqrt ( float ( self . D_HEAD ))
74
74
75
75
if self .embedding_dim and self .H != self .embedding_dim // self .D_HEAD :
76
76
raise ValueError (
@@ -119,7 +119,7 @@ def triton_preprocess(self, q, k, v):
119
119
v ,
120
120
)
121
121
122
- @register_benchmark ()
122
+ @register_benchmark (baseline = True )
123
123
def triton_flash_v2 (
124
124
self ,
125
125
q : torch .Tensor ,
@@ -129,7 +129,7 @@ def triton_flash_v2(
129
129
triton_q , triton_k , triton_v = self .triton_preprocess (q , k , v )
130
130
# full fp8 will be enabled if type of q,k,v is fp8
131
131
return lambda : triton_attention (
132
- triton_q , triton_k , triton_v , self .causal , self .sm_scale , "base "
132
+ triton_q , triton_k , triton_v , self .causal , self .sm_scale , "base_opt "
133
133
)
134
134
135
135
@register_benchmark ()
@@ -189,12 +189,14 @@ def get_ctx_vals():
189
189
device = self .device ,
190
190
requires_grad = self .requires_grad ,
191
191
)
192
+
192
193
k = torch .randn (
193
194
(BATCH , H , N_CTX , D_HEAD ),
194
195
dtype = torch .float16 ,
195
196
device = self .device ,
196
197
requires_grad = self .requires_grad ,
197
198
)
199
+
198
200
v = torch .randn (
199
201
(BATCH , H , N_CTX , D_HEAD ),
200
202
dtype = torch .float16 ,
@@ -203,6 +205,42 @@ def get_ctx_vals():
203
205
)
204
206
yield (q , k , v )
205
207
208
+ def accuracy (self , fn : Callable , baseline_fn : Callable ) -> bool :
209
+ """
210
+ Check accuracy of FP8 attention implementation against baseline.
211
+
212
+ FP8 operations have inherently lower precision, so we use relaxed tolerances.
213
+ Based on empirical testing, FP8 can introduce differences up to ~2.0.
214
+ """
215
+ try :
216
+ output = fn ()
217
+ baseline_output = baseline_fn ()
218
+
219
+ # Convert FP8 outputs to FP16 for comparison
220
+ if output .dtype in [torch .float8_e5m2 , torch .float8_e4m3fn ]:
221
+ output = output .to (torch .float16 )
222
+ if baseline_output .dtype in [torch .float8_e5m2 , torch .float8_e4m3fn ]:
223
+ baseline_output = baseline_output .to (torch .float16 )
224
+
225
+ # Validate outputs
226
+ if torch .isnan (output ).any () or torch .isinf (output ).any ():
227
+ return False
228
+ if torch .isnan (baseline_output ).any () or torch .isinf (baseline_output ).any ():
229
+ return False
230
+ if output .shape != baseline_output .shape :
231
+ return False
232
+
233
+ # FP8 attention uses relaxed tolerances due to:
234
+ # 1. FP8 quantization of Q, K, V inputs
235
+ # 2. FP8 quantization of attention weights (doesn't sum to exactly 1.0)
236
+ # 3. Accumulation differences in FP8 GEMM operations
237
+ result = torch .allclose (output , baseline_output , atol = 2.0 , rtol = 0.2 )
238
+
239
+ return result
240
+
241
+ except Exception :
242
+ return False
243
+
206
244
@register_metric ()
207
245
def flops (
208
246
self , fn_name : str , example_inputs : Any , metrics : BenchmarkOperatorMetrics
0 commit comments