 from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
 
 # don't truncate long kernel names
 pd.options.display.max_colwidth = 100
@@ -298,11 +299,15 @@ def main(
         "lowp",
         "ref",
     ), "experiment_filter must be one of `both`, `lowp`, `ref`"
-    assert mode_filter in (
-        "fwd_bwd",
-        "fwd",
-        "cast_only",
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`"
+    assert (
+        mode_filter
+        in (
+            "fwd_bwd",
+            "fwd",
+            "cast_only",
+            "cast_with_to_blocked",
+        )
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"
 
@@ -378,14 +383,26 @@ def main(
     # this function is only used for cast_only
     to_mx_func = MXTensor.to_mx
 
+    # this function is used for cast_with_to_blocked
+    def cast_with_to_blocked(x_hp):
+        x_mx = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        m, k = x_hp.shape
+        scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
+        return x_mx._data, scale_blocked
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
     print("grad_output.shape", grad_output.shape)
     print()
 
     def ref_forw_backward(x):
-        assert mode_filter != "cast_only", "unsupported"
+        assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported"
         if enable_activation_checkpointing:
             out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
         else:
@@ -403,6 +420,9 @@ def lowp_forw_backward_wrapper(x):
                 gemm_kernel_choice=config.gemm_kernel_choice,
             )
             return
+        elif mode_filter == "cast_with_to_blocked":
+            _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
+            return
 
         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -416,6 +436,7 @@ def lowp_forw_backward_wrapper(x):
         m_ref = torch.compile(m_ref, fullgraph=True)
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
+        cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
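For reference, a minimal standalone sketch of the cast path that the new `cast_with_to_blocked` mode times, run outside the profiling harness. It mirrors the helper added in this diff, but the tensor sizes, the fp8 element dtype, and the CUDA-event timing are illustrative assumptions rather than part of the change, and `gemm_kernel_choice` is left at its default:

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

# illustrative sizes; the benchmark script derives these from its own config
M, K, BLOCK_SIZE = 8192, 8192, 32
x_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

def cast_with_to_blocked(x_hp):
    # cast to an MX tensor, then rearrange the e8m0 scales into the
    # blocked layout expected by the low-precision gemm
    x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, BLOCK_SIZE)
    m, k = x_hp.shape
    scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // BLOCK_SIZE))
    return x_mx._data, scale_blocked

cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)

# warm up, then time the cast + to_blocked step in isolation
for _ in range(3):
    cast_with_to_blocked(x_hp)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
data, scale = cast_with_to_blocked(x_hp)
end.record()
torch.cuda.synchronize()
print(f"cast_with_to_blocked: {start.elapsed_time(end):.3f} ms")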