 )

 from torchao.float8 import (
-    CastConfig,
     Float8LinearConfig,
-    ScalingType,
     convert_to_float8_training,
 )
 from torchao.float8.roofline_utils import (
@@ -219,24 +217,6 @@ def run(
         scaling_type_weight="dynamic",
         scaling_type_grad_output="dynamic",
     )
-    fp8_mem_time_sympy_del_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )
-    fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=False,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )

     if gemm_time_strategy == "roofline":
         bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
@@ -258,16 +238,12 @@ def run(
         # roofline memory overhead estimates
         "fp8_oh_dyn_limit",
         "fp8_oh_dyn_nolimit",
-        "fp8_oh_del_limit",
-        "fp8_oh_del_nolimit",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_del_s",
         "fp8_dyn_axs_s",
         # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_del_sp",
         "fp8_dyn_axs_sp",
         # 'fp8_lw_sp',
     ]
@@ -309,12 +285,6 @@ def run(
         fp8_mem_time_dyn_nolimit_s = (
             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
-        fp8_mem_time_del_limit_s = (
-            fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_del_nolimit_s = (
-            fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )

         # create the model
         m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
@@ -333,19 +303,6 @@ def run(
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

-        # get the float8 delayed scaling gpu kernel time
-        torch._dynamo.reset()
-        config = Float8LinearConfig(
-            enable_amax_init=False,
-            enable_pre_and_post_forward=False,
-            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
-            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
-            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
-        )
-        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
-        m_fp8_del = torch.compile(m_fp8_del)
-        fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
-
         # get the float8 dynamic axiswise scaling gpu kernel time
         torch._dynamo.reset()
         config = Float8LinearConfig.from_recipe_name("rowwise")
@@ -374,16 +331,12 @@ def run(
             # roofline overhead estimates
            fp8_mem_time_dyn_limit_s,
            fp8_mem_time_dyn_nolimit_s,
-            fp8_mem_time_del_limit_s,
-            fp8_mem_time_del_nolimit_s,
             # e2e numbers
            bf16_time_actual_s,
            fp8_dyn_time_actual_s,
-            fp8_del_time_actual_s,
            fp8_dyn_axs_time_actual_s,
             # fp8_lw_time_actual_s,
            bf16_time_actual_s / fp8_dyn_time_actual_s,
-            bf16_time_actual_s / fp8_del_time_actual_s,
            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
             # bf16_time_actual_s / fp8_lw_time_actual_s,
         ]
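For reference, below is a minimal sketch of the two float8 paths this benchmark keeps after delayed scaling is removed: tensorwise dynamic scaling (the default Float8LinearConfig) and rowwise/axiswise scaling via Float8LinearConfig.from_recipe_name("rowwise"). The toy model, the M_val/K_val/N_val shapes, and the plain forward/backward timing stand-in are assumptions for illustration; they replace the script's own LNLinearSigmoid and get_gpu_kernel_time helpers, which are not shown in this diff.

import copy

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Assumed stand-in for the benchmark's LNLinearSigmoid model:
# layernorm -> linear -> sigmoid, run in bfloat16 on CUDA.
M_val, K_val, N_val = 4096, 4096, 4096
m_orig = (
    nn.Sequential(nn.LayerNorm(K_val), nn.Linear(K_val, N_val), nn.Sigmoid())
    .cuda()
    .bfloat16()
)
x = torch.randn(M_val, K_val, device="cuda", dtype=torch.bfloat16)

# Tensorwise dynamic scaling: the default Float8LinearConfig already uses
# dynamic scaling for input, weight, and grad_output, so no CastConfig /
# ScalingType arguments are needed once delayed scaling is gone.
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig), config=Float8LinearConfig())
m_fp8_dyn = torch.compile(m_fp8_dyn)
m_fp8_dyn(x).sum().backward()

# Rowwise (axiswise) dynamic scaling, as used for the fp8_dyn_axs measurements.
torch._dynamo.reset()
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
m_fp8_dyn_axs(x).sum().backward()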