 )
 
 from torchao.float8 import (
-    Float8LinearConfig,
     convert_to_float8_training,
 )
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
-from torchao.utils import is_sm_at_least_90, is_sm_at_least_100
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -155,21 +153,13 @@ def do_matmul(A, B):
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
-    if is_sm_at_least_90() and (not is_sm_at_least_100()):
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
-        fast_accum = True  # for axiswise
-        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
-    else:
-        f8_axs_time_s = -1.0
-
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
+        cache[key] = [bf16_time_s, f8_time_s]
         with open(cache_filename, "w") as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s, f8_axs_time_s
+    return bf16_time_s, f8_time_s
 
 
 def run(
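A note on the call pattern after this hunk (a hedged reading based only on the lines visible in this diff): `get_gemm_times` now returns a two-element `(bf16_time_s, f8_time_s)` tuple, so callers unpack two values instead of three. Minimal usage sketch; the shapes and the `None` cache path are illustrative, not taken from this PR:

```python
# Hedged usage sketch of the updated helper defined in this benchmark script.
# Signature per the call sites in this diff:
#   get_gemm_times(M, K, N, fast_accum, cache_filename)
M_val, K_val, N_val = 8192, 4096, 4096  # illustrative shapes
bf16_time_s, f8_time_s = get_gemm_times(M_val, K_val, N_val, True, None)
print(f"bf16 gemm: {bf16_time_s:.3e} s, fp8 tensorwise gemm: {f8_time_s:.3e} s")
```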
@@ -229,18 +219,13 @@ def run(
         # gemm microbenchmarks
         "bf16_gemm_s",
         "fp8_gemm_s",
-        "fp8_axs_gemm_time_s",
         # roofline memory overhead estimates
-        "fp8_oh_dyn_limit",
-        "fp8_oh_dyn_nolimit",
+        "fp8_oh_estimated",
+        "fp8_oh_ideal",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_dyn_axs_s",
-        # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_dyn_axs_sp",
-        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -251,18 +236,17 @@ def run(
            break
 
        if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(
+            bf16_g1, f8_g1 = get_gemm_times(
                M_val, K_val, N_val, True, gemm_cache_filename
            )
-            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(
+            bf16_g2, f8_g2 = get_gemm_times(
                M_val, N_val, K_val, False, gemm_cache_filename
            )
-            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(
+            bf16_g3, f8_g3 = get_gemm_times(
                K_val, M_val, N_val, False, gemm_cache_filename
            )
            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
        else:
            assert gemm_time_strategy == "roofline", "unsupported"
            bf16_time_val = (
@@ -271,8 +255,6 @@ def run(
            fp8_gemm_time_s = (
                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
            )
-            # for now, assume axiswise gemm is similar to tensorwise
-            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
        fp8_mem_time_dyn_limit_s = (
            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
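For readers following the roofline branch kept above: the gemm-time estimate is a sympy expression in the symbols M, K, N, and evaluating it is just chained substitution, exactly as in the `.subs(M, M_val).subs(K, K_val).subs(N, N_val)` lines. A hedged, self-contained sketch with an illustrative stand-in expression (the real one comes from `get_gemm_time_sympy` in `torchao.testing.float8.roofline_utils`):

```python
import sympy

# Stand-in roofline expression: 2*M*K*N flops over an assumed peak throughput.
# Both the formula and the peak value are illustrative, not from this PR.
M, K, N = sympy.symbols("M K N")
peak_bf16_flops = 989e12
gemm_time_s = 2 * M * K * N / peak_bf16_flops

# Same chained-substitution pattern as the diff's roofline branch.
estimate = gemm_time_s.subs(M, 8192).subs(K, 4096).subs(N, 4096)
print(float(estimate))
```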
@@ -299,28 +281,6 @@ def run(
        m_fp8_dyn = torch.compile(m_fp8_dyn)
        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-        # get the float8 dynamic axiswise scaling gpu kernel time, if supported
-        # on current hardware
-        if is_sm_at_least_90() and (not is_sm_at_least_100()):
-            torch._dynamo.reset()
-            config = Float8LinearConfig.from_recipe_name("rowwise")
-            m_fp8_dyn_axs = convert_to_float8_training(
-                copy.deepcopy(m_orig), config=config
-            )
-            m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-            fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
-        else:
-            fp8_dyn_axs_time_actual_s = -1.0
-
-        # get the lw recipe scaling gpu kernel time
-        # TODO(future PR): enable below once basic performance issues
-        # are fixed
-        # torch._dynamo.reset()
-        # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
-        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-        # m_fp8_lw = torch.compile(m_fp8_lw)
-        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
-
        results.append(
            [
                M_val,
@@ -329,18 +289,13 @@ def run(
                # gemm microbenchmarks
                bf16_time_val,
                fp8_gemm_time_s,
-                fp8_axs_gemm_time_s,
                # roofline overhead estimates
                fp8_mem_time_dyn_limit_s,
                fp8_mem_time_dyn_nolimit_s,
                # e2e numbers
                bf16_time_actual_s,
                fp8_dyn_time_actual_s,
-                fp8_dyn_axs_time_actual_s,
-                # fp8_lw_time_actual_s,
                bf16_time_actual_s / fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
-                # bf16_time_actual_s / fp8_lw_time_actual_s,
            ]
        )
 
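For orientation, the path this benchmark keeps measuring (the `fp8_dyn_*` columns) is tensorwise dynamic scaling via `convert_to_float8_training` plus `torch.compile`, as in the unchanged context lines above. A minimal sketch of that flow outside the benchmark harness; the model, shapes, and dtype are illustrative stand-ins for the script's `LNLinearSigmoid`, not taken from this PR, and a float8-capable GPU is assumed:

```python
import copy

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training

# Toy stand-in for the benchmark's LNLinearSigmoid; sizes are illustrative.
m_orig = nn.Sequential(
    nn.LayerNorm(4096),
    nn.Linear(4096, 4096, bias=False),
    nn.Sigmoid(),
).cuda().bfloat16()

# Default config = tensorwise dynamic scaling, the path retained by this change.
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
y = m_fp8_dyn(x)
```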