Commit 8d38814

add MX support to lowp training profiling script (#1765)
* Update [ghstack-poisoned]
1 parent 2a3fbff commit 8d38814

File tree

3 files changed: +157 -108 lines changed
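
For orientation, a minimal sketch of how the renamed script might be driven after this change. The module path and the programmatic call are assumptions (the CLI entry point is not part of this diff); the argument names come from main()'s new signature, and the mx recipe string is a hypothetical example.

# hypothetical invocation of the renamed profiling script
from benchmarks.float8.profile_lowp_training import main

# float8 profiling with the default Float8LinearConfig
main("out/prof_float8", experiment_filter="both", mode_filter="fwd_bwd")

# mx profiling; recipe string is a hypothetical example
main("out/prof_mx", mx_recipe_name="mxfp8_emulated", experiment_filter="lowp")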

benchmarks/float8/profile_linear_float8.py renamed to benchmarks/float8/profile_lowp_training.py

Lines changed: 111 additions & 81 deletions
@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+# This is a convenience script to profile fwd+bwd of individual layers with
+# float8 training or mx training on a single GPU.
+
 import copy
 import functools
 import io
@@ -38,12 +41,13 @@
 
 from torchao.float8.config import (
     Float8LinearConfig,
-    ScalingType,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 
 # don't truncate long kernel names
 pd.options.display.max_colwidth = 100
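
The three new imports carry the MX path through the script: MXLinearConfig selects the recipe, swap_linear_with_mx_linear swaps eligible nn.Linear modules in place, and MXTensor backs the cast-only mode further down. A minimal sketch of the two conversion paths, assuming a CUDA device and using a hypothetical recipe name:

import copy

import torch
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda().bfloat16()

# float8 path: swap eligible nn.Linear modules for float8 training modules
m_float8 = copy.deepcopy(m)
convert_to_float8_training(m_float8, config=Float8LinearConfig())

# mx path: swap eligible nn.Linear modules for MX linear modules
m_mx = copy.deepcopy(m)
mx_config = MXLinearConfig.from_recipe_name("mxfp8_emulated")  # hypothetical recipe name
swap_linear_with_mx_linear(m_mx, config=mx_config)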
@@ -257,7 +261,6 @@ def profile_function(
 # set up AC for max(abs(tensor))
 # context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
 ops_to_save = [
-    torch.ops.aten.abs.default,
     torch.ops.aten.max.default,
 ]
 
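
Dropping aten.abs.default from ops_to_save means max(abs(tensor)) now saves only the final max() result under selective activation checkpointing, recomputing abs() in backward. A minimal sketch of how such a list feeds selective checkpointing, assuming PyTorch 2.4+ where create_selective_checkpoint_contexts accepts a list of ops (the model and input here are illustrative; the script builds its context_fn from its own policy_fn):

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

ops_to_save = [torch.ops.aten.max.default]  # abs is now recomputed, not saved

def context_fn():
    return create_selective_checkpoint_contexts(ops_to_save)

m = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(m, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()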

@@ -275,50 +278,52 @@ def policy_fn(ctx, op, *args, **kwargs):
 def main(
     profile_path_prefix: pathlib.Path,
     compile: bool = True,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
-    recipe_name: Optional[str] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
     model_type: str = "linear",
-    dtype_filter: str = "both",
-    add_inductor_metadata_to_trace: bool = True,
+    experiment_filter: str = "both",
+    add_inductor_metadata_to_trace: bool = False,
     enable_activation_checkpointing: bool = False,
+    mode_filter: str = "fwd_bwd",
+    forward_only: bool = False,
 ):
     assert model_type in (
         "linear",
         "ln_linear",
         "norm_ffn_norm",
         "norm_ffn_norm_small",
     ), "unsupported"
-    assert dtype_filter in ("both", "float8", "bfloat16")
-
-    scaling_type_input = ScalingType(scaling_type_input)
-    scaling_type_weight = ScalingType(scaling_type_weight)
-    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
-
-    if recipe_name is None:
-        config = get_test_float8_linear_config(
-            scaling_type_input,
-            scaling_type_weight,
-            scaling_type_grad_output,
-            emulate=False,
-        )
-    elif recipe_name is not None:
-        config = Float8LinearConfig.from_recipe_name(recipe_name)
-
-    scaling_repr = "_".join(
-        [
-            s.short_str()
-            for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
-        ]
-    )
+    assert experiment_filter in (
+        "both",
+        "lowp",
+        "ref",
+    ), "experiment_filter must be one of `both`, `lowp`, `ref`"
+    assert mode_filter in (
+        "fwd_bwd",
+        "fwd",
+        "cast_only",
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`"
+    if mode_filter == "cast_only":
+        assert experiment_filter == "lowp", "unsupported"
+
+    assert not (
+        float8_recipe_name is not None and mx_recipe_name is not None
+    ), "either float8_recipe_name or mx_recipe_name can be specified, but not both"
+
+    if float8_recipe_name is None and mx_recipe_name is None:
+        config = Float8LinearConfig()
+    elif float8_recipe_name is not None:
+        config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+    elif mx_recipe_name is not None:
+        config = MXLinearConfig.from_recipe_name(mx_recipe_name)
 
     print(f"Compile is set to | {compile}")
     print(f"model_type is set to | {model_type}")
-    print(f"scaling_repr is set to | {scaling_repr}")
     print(
         f"enable_activation_checkpointing is set to {enable_activation_checkpointing}"
     )
+    print(f"mode_filter is set to {mode_filter}")
+    print(f"config: {config}")
 
     device = "cuda"
     ref_dtype = torch.bfloat16
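
The config selection above reduces to three mutually exclusive branches. A standalone sketch for illustration (the helper name is ours, and recipe strings are placeholders, not an exhaustive list):

from typing import Optional

from torchao.float8.config import Float8LinearConfig
from torchao.prototype.mx_formats.config import MXLinearConfig

def resolve_config(
    float8_recipe_name: Optional[str] = None,
    mx_recipe_name: Optional[str] = None,
):
    # at most one recipe may be specified; no recipe means the float8 default
    assert not (float8_recipe_name is not None and mx_recipe_name is not None)
    if mx_recipe_name is not None:
        return MXLinearConfig.from_recipe_name(mx_recipe_name)
    if float8_recipe_name is not None:
        return Float8LinearConfig.from_recipe_name(float8_recipe_name)
    return Float8LinearConfig()
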
@@ -359,36 +364,58 @@ def main(
 
     m_ref = m_ref.to(device).to(ref_dtype)
 
-    m_float8 = copy.deepcopy(m_ref)
-    convert_to_float8_training(m_float8, config=config)
+    # get gradient shape
+    with torch.no_grad():
+        _ = m_ref(input_tensor)
+        grad_output = torch.ones_like(_)
+
+    m_lowp = copy.deepcopy(m_ref)
+    if mx_recipe_name is None:
+        convert_to_float8_training(m_lowp, config=config)
+    else:
+        swap_linear_with_mx_linear(m_lowp, config=config)
+
+    # this function is only used for cast_only
+    to_mx_func = MXTensor.to_mx
+
+    print("m_ref", m_ref)
+    print("m_lowp", m_lowp)
+    print("input_tensor.shape", input_tensor.shape)
+    print("grad_output.shape", grad_output.shape)
+    print()
 
     def ref_forw_backward(x):
+        assert mode_filter != "cast_only", "unsupported"
         if enable_activation_checkpointing:
             out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
         else:
             out = m_ref(x)
-        out.sum().backward()
+        if mode_filter == "fwd_bwd":
+            out.backward(grad_output)
+
+    def lowp_forw_backward_wrapper(x):
+        if mode_filter == "cast_only":
+            # just cast and return early
+            _input_tensor_mx = to_mx_func(
+                input_tensor,
+                config.elem_dtype,
+                config.block_size,
+                gemm_kernel_choice=config.gemm_kernel_choice,
+            )
+            return
 
-    def float8_forw(x):
         if enable_activation_checkpointing:
-            out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn)
+            out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
         else:
-            out = m_float8(x)
-        return out
-
-    def float8_forw_backward_wrapper(x):
-        # TODO(future PR): this wrapper is for delayed scaling, we can clean it
-        # up now that delayed scaling is deprecated.
-        out = float8_forw(x)
-
-        # out.sum().backward() is also not torch.compile fullgraph
-        # friendly
-        with record_function("backward"):
-            out.sum().backward()
+            out = m_lowp(x)
+        if mode_filter == "fwd_bwd":
+            with record_function("backward"):
+                out.backward(grad_output)
 
     if compile:
         m_ref = torch.compile(m_ref, fullgraph=True)
-        float8_forw = torch.compile(float8_forw, fullgraph=True)
+        m_lowp = torch.compile(m_lowp, fullgraph=True)
+        to_mx_func = torch.compile(to_mx_func, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
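
Passing a precomputed grad_output to out.backward() replaces the old out.sum().backward() pattern. The sum() approach inserts extra kernels (filling the scalar's gradient with 1.0 and copying it into the next op's grad_out) that the profiler then has to filter away; a ones tensor built once outside the profiled region sidesteps them, which is why the aten::fill_ / aten::copy_ special cases disappear from utils.py below. A minimal standalone illustration:

import torch

m = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# build the gradient once, outside the profiled region
with torch.no_grad():
    grad_output = torch.ones_like(m(x))

out = m(x)
out.backward(grad_output)  # same gradients as out.sum().backward(), fewer kernels
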
@@ -398,15 +425,21 @@ def float8_forw_backward_wrapper(x):
     else:
         f = io.StringIO()
         context = redirect_stdout(f)
+
+    # if we are skipping forward, enable torch.no_grad()
+    maybe_no_grad_context = (
+        torch.no_grad() if mode_filter != "fwd_bwd" else nullcontext()
+    )
+
     try:
-        with context:
+        with context, maybe_no_grad_context:
             profile_iters = 5
-            ref_times, float8_times = None, None
+            ref_times, lowp_times = None, None
             data = []
 
             num_leaf_tensors = 1 + len(list(m_ref.parameters()))
 
-            if dtype_filter != "float8":
+            if experiment_filter != "lowp":
                 # Profile Reference Model
                 print("profiling ref")
                 ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
@@ -452,50 +485,46 @@ def float8_forw_backward_wrapper(x):
                         ]
                     )
 
-            if dtype_filter != "bfloat16":
-                # Profile Float8 Model
-                print("profiling float8")
-                float8_trace_suffix = (
-                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-                )
-                float8_log_suffix = (
-                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
-                )
-                trace_float8_path = profile_path_prefix + float8_trace_suffix
-                log_float8_path = profile_path_prefix + float8_log_suffix
-                trace_float8_modified_path = trace_float8_path.replace(
+            if experiment_filter != "ref":
+                # Profile lowp Model
+                print("profiling lowp")
+                lowp_trace_suffix = f"_{model_type}_lowp_compile_{compile}.json"
+                lowp_log_suffix = f"_{model_type}_lowp_compile_{compile}.txt"
+                trace_lowp_path = profile_path_prefix + lowp_trace_suffix
+                log_lowp_path = profile_path_prefix + lowp_log_suffix
+                trace_lowp_modified_path = trace_lowp_path.replace(
                     ".json", "_modified.json"
                 )
                 profile_config = ProfileConfig(
-                    trace_float8_path,
-                    log_float8_path,
-                    trace_float8_modified_path,
-                    float8_trace_suffix,
+                    trace_lowp_path,
+                    log_lowp_path,
+                    trace_lowp_modified_path,
+                    lowp_trace_suffix,
                     iters=profile_iters,
                     warmup_iters=2,
                     sync=True,
                 )
                 p = profile_function(
                     profile_config,
-                    float8_forw_backward_wrapper,
+                    lowp_forw_backward_wrapper,
                     add_inductor_metadata_to_trace,
                     input_tensor,
                 )
-                print(f"saved profiling trace to {trace_float8_path}")
+                print(f"saved profiling trace to {trace_lowp_path}")
                 if add_inductor_metadata_to_trace:
-                    print(f"saved torch logs to {log_float8_path}")
-                    print(f"saved modified trace to {trace_float8_modified_path}")
-                float8_times = profiler_output_to_filtered_time_by_kernel_name(
+                    print(f"saved torch logs to {log_lowp_path}")
+                    print(f"saved modified trace to {trace_lowp_modified_path}")
+                lowp_times = profiler_output_to_filtered_time_by_kernel_name(
                     p, profile_iters, num_leaf_tensors
                 )
                 total_time_ms = (
-                    sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                    sum(v for v in lowp_times.values()) / 1e3 / profile_iters
                 )
-                for k, v in float8_times.items():
+                for k, v in lowp_times.items():
                     v_ms = v / 1e3 / profile_iters
                     data.append(
                         [
-                            "1_float8",
+                            "1_lowp",
                             k,
                             kernel_name_to_category(k),
                             v / 1e3 / profile_iters,
@@ -509,6 +538,7 @@ def float8_forw_backward_wrapper(x):
     # print the redirected stdout back to regular stdout
     print(f.getvalue())
 
+    # TODO(future PR): this seems to no longer work, fix it or delete it
     if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
         # populate the triton kernel bandwidth
         for line in f.getvalue().split("\n"):
@@ -546,13 +576,13 @@ def float8_forw_backward_wrapper(x):
         fill_value=0,
         margins=True,
     )
-    # drop last row, which has totals across ref + float8 which does not make sense
+    # drop last row, which has totals across ref + lowp which does not make sense
    df_p = df_p[:-1]
    df_p = df_p.transpose()
 
-    if dtype_filter == "both":
-        df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
-        df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+    if experiment_filter == "both":
+        df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"]
+        df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"]
 
     print("\nSummary of time (ms) by kernel category\n\n", df_p)
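
What the two new summary columns report, illustrated with made-up numbers: lowp_div_ref below 1.0 means the low-precision run is faster for that kernel category, and ref_div_lowp is the corresponding speedup.

import pandas as pd

# toy times in ms, kernel categories as rows and experiments as columns
df_p = pd.DataFrame(
    {"0_ref": [1.20, 0.30], "1_lowp": [0.80, 0.45]},
    index=["0_gemm", "1_other"],
)
df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"]
df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"]
print(df_p)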

benchmarks/float8/utils.py

Lines changed: 9 additions & 26 deletions
@@ -73,14 +73,6 @@ def profiler_output_to_filtered_time_by_kernel_name(
             # forward pass sum
             assert e.count == num_iter, f"unexpected number of iter for {e.key}"
             continue
-        elif e.key == "aten::fill_":
-            # filling the forward pass sum with 1.0
-            assert e.count == num_iter, f"unexpected number of iter for {e.key}"
-            continue
-        elif e.key == "aten::copy_":
-            # copying 1.0 from grad_out of `sum` to grad_out of next op
-            assert e.count == num_iter, f"unexpected number of iter for {e.key}"
-            continue
         elif e.key == "aten::add_":
             # accumulating gradients into leaf tensors
             assert e.count == (
@@ -110,25 +102,16 @@ def profiler_output_to_gpu_time_for_key(prof, key):
 
 def kernel_name_to_category(k):
     # number prefix is for easy sorting
-    if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
-        return "0_gemm"
-    elif (
-        # max(abs(tensor))
-        ("abs" in k and "max" in k)
-        or
-        # casting pointwise to float8
-        ("clamp" in k)
-        or
-        # things related to scaled_mm
-        ("scaled_mm" in k)
-        or
-        # syncing amaxes and scales
-        ("roll" in k)
+    if k in (
+        "aten::mm",
+        "aten::addmm",
+        "aten::_scaled_mm",
+        "torchao::mx_fp8_bf16",
+        "torchao::mx_fp4_bf16",
     ):
-        # note: the above filter is approximate and will give false
-        # positives if model code contains other code to abs/max/clamp
-        return "1_f8_overhead"
-    return "2_other"
+        return "0_gemm"
+    else:
+        return "1_other"
 
 
 def parse_bw_and_kernel_name(line):
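
A quick sanity check of the simplified categorization: only the listed gemm kernels (now including the torchao mx gemms) map to 0_gemm, and everything else, including names the old heuristic would have tagged 1_f8_overhead, maps to 1_other. The triton kernel name below is an illustrative placeholder:

for k in ("aten::_scaled_mm", "torchao::mx_fp8_bf16", "triton_poi_fused_abs_max_0"):
    print(k, "->", kernel_name_to_category(k))
# aten::_scaled_mm -> 0_gemm
# torchao::mx_fp8_bf16 -> 0_gemm
# triton_poi_fused_abs_max_0 -> 1_other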
