Commit 09ebb12

mx bench: add cast with to_blocked (#1771)
Update [ghstack-poisoned]
1 parent bac039f commit 09ebb12

2 files changed: +28 -6

benchmarks/float8/profile_lowp_training.py

Lines changed: 27 additions & 6 deletions
@@ -48,6 +48,7 @@
 from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked

 # don't truncate long kernel names
 pd.options.display.max_colwidth = 100
@@ -298,11 +299,15 @@ def main(
         "lowp",
         "ref",
     ), "experiment_filter must be one of `both`, `lowp`, `ref`"
-    assert mode_filter in (
-        "fwd_bwd",
-        "fwd",
-        "cast_only",
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`"
+    assert (
+        mode_filter
+        in (
+            "fwd_bwd",
+            "fwd",
+            "cast_only",
+            "cast_with_to_blocked",
+        )
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"

@@ -378,14 +383,26 @@ def main(
     # this function is only used for cast_only
     to_mx_func = MXTensor.to_mx

+    # this function is used for cast_with_to_blocked
+    def cast_with_to_blocked(x_hp):
+        x_mx = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        m, k = x_hp.shape
+        scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
+        return x_mx._data, scale_blocked
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
     print("grad_output.shape", grad_output.shape)
     print()

     def ref_forw_backward(x):
-        assert mode_filter != "cast_only", "unsupported"
+        assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported"
         if enable_activation_checkpointing:
             out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
         else:
@@ -403,6 +420,9 @@ def lowp_forw_backward_wrapper(x):
                 gemm_kernel_choice=config.gemm_kernel_choice,
             )
             return
+        elif mode_filter == "cast_with_to_blocked":
+            _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
+            return

         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -416,6 +436,7 @@ def lowp_forw_backward_wrapper(x):
         m_ref = torch.compile(m_ref, fullgraph=True)
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
+        cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)

     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
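For reference, a minimal standalone sketch of the new cast_with_to_blocked path, using only the calls visible in this diff (MXTensor.to_mx and to_blocked); the shapes, device, and element dtype below are illustrative assumptions, not values taken from the benchmark's config:

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

# illustrative values; the benchmark reads these from its MXLinearConfig
M, K = 2048, 4096
block_size = 32
x_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

# cast the high-precision input to MX (low-precision data plus e8m0 scales)
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size)

# rearrange the per-block scales into the blocked layout expected by the
# MX gemm kernels, mirroring cast_with_to_blocked() in the diff above
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(M, K // block_size))
data, scale = x_mx._data, scale_blocked

This is the work the new cast_with_to_blocked mode times in isolation; like to_mx_func, the helper is wrapped in torch.compile when the benchmark compiles its models, so the measurement covers both the cast to MX and the scale rearrangement.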

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ def mx_mm(aten_op, args, kwargs=None):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert b._data.t().is_contiguous()
+        # TODO(future PR): use block_size instead of hardcoding 32
         a_scale = a._scale_e8m0.view(M, K // 32)
         b_scale = b._scale_e8m0.view(N, K // 32)
         a_scale_block = to_blocked(a_scale)

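A rough sketch of what the TODO above might look like once the hardcoded 32 is removed; this assumes the block size can be read off the MXTensor (written here as a hypothetical a._block_size attribute, which this hunk does not show):

# hypothetical follow-up to the TODO: derive the scale shape from the
# tensor's block size instead of hardcoding 32
block_size = a._block_size
a_scale = a._scale_e8m0.view(M, K // block_size)
b_scale = b._scale_e8m0.view(N, K // block_size)
a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)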