# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+ # This is a convenience script to profile fwd+bwd of individual layers with
+ # float8 training or mx training on a single GPU.
+
import copy
import functools
import io

from torchao.float8.config import (
    Float8LinearConfig,
-     ScalingType,
)
from torchao.float8.float8_linear_utils import (
    convert_to_float8_training,
)
- from torchao.testing.float8.test_utils import get_test_float8_linear_config
+ from torchao.prototype.mx_formats.config import MXLinearConfig
+ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+ from torchao.prototype.mx_formats.mx_tensor import MXTensor

# don't truncate long kernel names
pd.options.display.max_colwidth = 100
@@ -257,7 +261,6 @@ def profile_function(
# set up AC for max(abs(tensor))
# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
ops_to_save = [
-     torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
]

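For context, a hedged sketch of how an `ops_to_save` list like the one above is typically wired into selective activation checkpointing. The script's own `policy_fn` and `context_fn` are only partially visible in this diff, so the wiring below follows the documented PyTorch pattern rather than this file's exact code; it assumes PyTorch 2.4+ for `CheckpointPolicy` and `create_selective_checkpoint_contexts`.

import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

ops_to_save = [
    torch.ops.aten.max.default,
]

def policy_fn(ctx, op, *args, **kwargs):
    # save the outputs of the listed ops during forward, recompute everything else
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

# usage inside a forward wrapper, matching the checkpoint(...) calls later in this diff:
# out = checkpoint(module, x, use_reentrant=False, context_fn=context_fn)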
@@ -275,50 +278,52 @@ def policy_fn(ctx, op, *args, **kwargs):
def main(
    profile_path_prefix: pathlib.Path,
    compile: bool = True,
-     scaling_type_input: str = "dynamic",
-     scaling_type_weight: str = "dynamic",
-     scaling_type_grad_output: str = "dynamic",
-     recipe_name: Optional[str] = None,
+     float8_recipe_name: Optional[str] = None,
+     mx_recipe_name: Optional[str] = None,
    model_type: str = "linear",
-     dtype_filter: str = "both",
-     add_inductor_metadata_to_trace: bool = True,
+     experiment_filter: str = "both",
+     add_inductor_metadata_to_trace: bool = False,
    enable_activation_checkpointing: bool = False,
+     mode_filter: str = "fwd_bwd",
+     forward_only: bool = False,
):
    assert model_type in (
        "linear",
        "ln_linear",
        "norm_ffn_norm",
        "norm_ffn_norm_small",
    ), "unsupported"
-     assert dtype_filter in ("both", "float8", "bfloat16")
-
-     scaling_type_input = ScalingType(scaling_type_input)
-     scaling_type_weight = ScalingType(scaling_type_weight)
-     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
-
-     if recipe_name is None:
-         config = get_test_float8_linear_config(
-             scaling_type_input,
-             scaling_type_weight,
-             scaling_type_grad_output,
-             emulate=False,
-         )
-     elif recipe_name is not None:
-         config = Float8LinearConfig.from_recipe_name(recipe_name)
-
-     scaling_repr = "_".join(
-         [
-             s.short_str()
-             for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
-         ]
-     )
+     assert experiment_filter in (
+         "both",
+         "lowp",
+         "ref",
+     ), "experiment_filter must be one of `both`, `lowp`, `ref`"
+     assert mode_filter in (
+         "fwd_bwd",
+         "fwd",
+         "cast_only",
+     ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`"
+     if mode_filter == "cast_only":
+         assert experiment_filter == "lowp", "unsupported"
+
+     assert not (
+         float8_recipe_name is not None and mx_recipe_name is not None
+     ), "either float8_recipe_name or mx_recipe_name can be specified, but not both"
+
+     if float8_recipe_name is None and mx_recipe_name is None:
+         config = Float8LinearConfig()
+     elif float8_recipe_name is not None:
+         config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+     elif mx_recipe_name is not None:
+         config = MXLinearConfig.from_recipe_name(mx_recipe_name)

    print(f"Compile is set to | {compile}")
    print(f"model_type is set to | {model_type}")
-     print(f"scaling_repr is set to | {scaling_repr}")
    print(
        f"enable_activation_checkpointing is set to {enable_activation_checkpointing}"
    )
+     print(f"mode_filter is set to {mode_filter}")
+     print(f"config: {config}")

    device = "cuda"
    ref_dtype = torch.bfloat16
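For reference, a hedged sketch of how the updated entry point might be driven after this change. The CLI wiring (for example a fire.Fire(main) hook) is outside this diff, and the recipe names below are illustrative placeholders rather than values taken from this commit.

# bf16 reference vs default float8 config, forward + backward
main("/tmp/profile_f8", compile=True, experiment_filter="both", mode_filter="fwd_bwd")

# a named float8 recipe (recipe name assumed, not part of this diff)
main("/tmp/profile_f8_rowwise", float8_recipe_name="rowwise")

# an MX recipe, low-precision experiment only, timing just the high-precision -> MX cast
main(
    "/tmp/profile_mx_cast",
    mx_recipe_name="mxfp8_emulated",  # assumed recipe name
    experiment_filter="lowp",
    mode_filter="cast_only",
)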
@@ -359,36 +364,58 @@ def main(

    m_ref = m_ref.to(device).to(ref_dtype)

-     m_float8 = copy.deepcopy(m_ref)
-     convert_to_float8_training(m_float8, config=config)
+     # get gradient shape
+     with torch.no_grad():
+         _ = m_ref(input_tensor)
+         grad_output = torch.ones_like(_)
+
+     m_lowp = copy.deepcopy(m_ref)
+     if mx_recipe_name is None:
+         convert_to_float8_training(m_lowp, config=config)
+     else:
+         swap_linear_with_mx_linear(m_lowp, config=config)
+
+     # this function is only used for cast_only
+     to_mx_func = MXTensor.to_mx
+
+     print("m_ref", m_ref)
+     print("m_lowp", m_lowp)
+     print("input_tensor.shape", input_tensor.shape)
+     print("grad_output.shape", grad_output.shape)
+     print()

    def ref_forw_backward(x):
+         assert mode_filter != "cast_only", "unsupported"
        if enable_activation_checkpointing:
            out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
        else:
            out = m_ref(x)
-         out.sum().backward()
+         if mode_filter == "fwd_bwd":
+             out.backward(grad_output)
+
+     def lowp_forw_backward_wrapper(x):
+         if mode_filter == "cast_only":
+             # just cast and return early
+             _input_tensor_mx = to_mx_func(
+                 input_tensor,
+                 config.elem_dtype,
+                 config.block_size,
+                 gemm_kernel_choice=config.gemm_kernel_choice,
+             )
+             return

-     def float8_forw(x):
        if enable_activation_checkpointing:
-             out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn)
+             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
        else:
-             out = m_float8(x)
-         return out
-
-     def float8_forw_backward_wrapper(x):
-         # TODO(future PR): this wrapper is for delayed scaling, we can clean it
-         # up now that delayed scaling is deprecated.
-         out = float8_forw(x)
-
-         # out.sum().backward() is also not torch.compile fullgraph
-         # friendly
-         with record_function("backward"):
-             out.sum().backward()
+             out = m_lowp(x)
+         if mode_filter == "fwd_bwd":
+             with record_function("backward"):
+                 out.backward(grad_output)

    if compile:
        m_ref = torch.compile(m_ref, fullgraph=True)
-         float8_forw = torch.compile(float8_forw, fullgraph=True)
+         m_lowp = torch.compile(m_lowp, fullgraph=True)
+         to_mx_func = torch.compile(to_mx_func, fullgraph=True)

    # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
    # to populate triton kernel bandwidth further down in the script
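As a standalone illustration of the cast_only path added above: for an MX config it times a single high-precision to MX cast of the input, with no matmul and no backward. This is a hedged sketch, not code from this commit; the recipe name is an assumed placeholder and the shape is arbitrary, while the to_mx call mirrors the wrapper above.

import torch
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")  # assumed recipe name
x = torch.randn(16384, 4096, dtype=torch.bfloat16, device="cuda")

# compile the cast and run it without autograd, mirroring the compile and
# no_grad handling in the surrounding script
to_mx_func = torch.compile(MXTensor.to_mx, fullgraph=True)
with torch.no_grad():
    x_mx = to_mx_func(
        x,
        config.elem_dtype,
        config.block_size,
        gemm_kernel_choice=config.gemm_kernel_choice,
    )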
@@ -398,15 +425,21 @@ def float8_forw_backward_wrapper(x):
    else:
        f = io.StringIO()
        context = redirect_stdout(f)
+
+     # if we are skipping backward, enable torch.no_grad()
+     maybe_no_grad_context = (
+         torch.no_grad() if mode_filter != "fwd_bwd" else nullcontext()
+     )
+
    try:
-         with context:
+         with context, maybe_no_grad_context:
            profile_iters = 5
-             ref_times, float8_times = None, None
+             ref_times, lowp_times = None, None
            data = []

            num_leaf_tensors = 1 + len(list(m_ref.parameters()))

-             if dtype_filter != "float8":
+             if experiment_filter != "lowp":
                # Profile Reference Model
                print("profiling ref")
                ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
@@ -452,50 +485,46 @@ def float8_forw_backward_wrapper(x):
                        ]
                    )

-             if dtype_filter != "bfloat16":
-                 # Profile Float8 Model
-                 print("profiling float8")
-                 float8_trace_suffix = (
-                     f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-                 )
-                 float8_log_suffix = (
-                     f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
-                 )
-                 trace_float8_path = profile_path_prefix + float8_trace_suffix
-                 log_float8_path = profile_path_prefix + float8_log_suffix
-                 trace_float8_modified_path = trace_float8_path.replace(
+             if experiment_filter != "ref":
+                 # Profile lowp Model
+                 print("profiling lowp")
+                 lowp_trace_suffix = f"_{model_type}_lowp_compile_{compile}.json"
+                 lowp_log_suffix = f"_{model_type}_lowp_compile_{compile}.txt"
+                 trace_lowp_path = profile_path_prefix + lowp_trace_suffix
+                 log_lowp_path = profile_path_prefix + lowp_log_suffix
+                 trace_lowp_modified_path = trace_lowp_path.replace(
                    ".json", "_modified.json"
                )
                profile_config = ProfileConfig(
-                     trace_float8_path,
-                     log_float8_path,
-                     trace_float8_modified_path,
-                     float8_trace_suffix,
+                     trace_lowp_path,
+                     log_lowp_path,
+                     trace_lowp_modified_path,
+                     lowp_trace_suffix,
                    iters=profile_iters,
                    warmup_iters=2,
                    sync=True,
                )
                p = profile_function(
                    profile_config,
-                     float8_forw_backward_wrapper,
+                     lowp_forw_backward_wrapper,
                    add_inductor_metadata_to_trace,
                    input_tensor,
                )
-                 print(f"saved profiling trace to {trace_float8_path}")
+                 print(f"saved profiling trace to {trace_lowp_path}")
                if add_inductor_metadata_to_trace:
-                     print(f"saved torch logs to {log_float8_path}")
-                     print(f"saved modified trace to {trace_float8_modified_path}")
-                 float8_times = profiler_output_to_filtered_time_by_kernel_name(
+                     print(f"saved torch logs to {log_lowp_path}")
+                     print(f"saved modified trace to {trace_lowp_modified_path}")
+                 lowp_times = profiler_output_to_filtered_time_by_kernel_name(
                    p, profile_iters, num_leaf_tensors
                )
                total_time_ms = (
-                     sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                     sum(v for v in lowp_times.values()) / 1e3 / profile_iters
                )
-                 for k, v in float8_times.items():
+                 for k, v in lowp_times.items():
                    v_ms = v / 1e3 / profile_iters
                    data.append(
                        [
-                             "1_float8",
+                             "1_lowp",
                            k,
                            kernel_name_to_category(k),
                            v / 1e3 / profile_iters,
@@ -509,6 +538,7 @@ def float8_forw_backward_wrapper(x):
    # print the redirected stdout back to regular stdout
    print(f.getvalue())

+     # TODO(future PR): this seems to no longer work, fix it or delete it
    if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
        # populate the triton kernel bandwidth
        for line in f.getvalue().split("\n"):
@@ -546,13 +576,13 @@ def float8_forw_backward_wrapper(x):
        fill_value=0,
        margins=True,
    )
-     # drop last row, which has totals across ref + float8 which does not make sense
+     # drop last row, which has totals across ref + lowp which does not make sense
    df_p = df_p[:-1]
    df_p = df_p.transpose()

-     if dtype_filter == "both":
-         df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
-         df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+     if experiment_filter == "both":
+         df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"]
+         df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"]

    print("\nSummary of time (ms) by kernel category\n\n", df_p)