@@ -1,3 +1,5 @@
+import functools
+
 import torch
 from torch import Tensor
@@ -606,6 +608,27 @@ def _(
     return input_scale.new_empty(*input.shape[:-1], weight.shape[0])


+@functools.lru_cache()
+def _get_dtypes():
+    """TODO: once e8m0 is hardened and in a major release, remove uint8 support."""
+    if hasattr(torch, "float8_e8m0fnu"):
+        return (torch.uint8, torch.float8_e8m0fnu)
+    return (torch.uint8,)
+
+
+def _check_scale_dtypes(A_scale, B_scale):
+    allowed_dtypes = _get_dtypes()
+
+    torch._check(
+        A_scale.dtype in allowed_dtypes,
+        lambda: f"A_scale tensor must be uint8 or float8_e8m0fnu, got {A_scale.dtype}",
+    )
+    torch._check(
+        B_scale.dtype in allowed_dtypes,
+        lambda: f"B_scale tensor must be uint8 or float8_e8m0fnu, got {B_scale.dtype}",
+    )
+
+
 def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Defines a matmul between two fp8 tensors w/ MX scales in E8M0 and returns a bf16 tensor.

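A quick sanity check of the new helper (a minimal sketch, not part of the diff; the tensor names and shapes are illustrative, and it relies on torch._check raising RuntimeError when its condition is false):

import torch

scale_ok = torch.full((128, 4), 127, dtype=torch.uint8)  # E8M0 bit patterns stored as uint8
_check_scale_dtypes(scale_ok, scale_ok)                  # passes: uint8 is allowed on every build

scale_bad = torch.zeros(128, 4, dtype=torch.float32)
try:
    _check_scale_dtypes(scale_ok, scale_bad)
except RuntimeError as e:
    print(e)  # "B_scale tensor must be uint8 or float8_e8m0fnu, got torch.float32"

# Thanks to functools.lru_cache, the hasattr probe runs once per process
# and later calls return the same cached tuple:
assert _get_dtypes() is _get_dtypes()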
@@ -625,25 +648,7 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
         MXN bf16 Tensor

     """
-    torch._check(
-        A.dtype == torch.float8_e4m3fn,
-        lambda: f"Input tensor A must be float8_e4m3fn, got {A.dtype}",
-    )
-    torch._check(
-        B.dtype == torch.float8_e4m3fn,
-        lambda: f"Input tensor B must be float8_e4m3fn, got {B.dtype}",
-    )
-
-    # TODO - Once e8m0 dtype is added to core udpate
-    # Check scale tensors are uint8
-    torch._check(
-        A_scale.dtype == torch.uint8,
-        lambda: f"A_scale tensor must be uint8, got {A_scale.dtype}",
-    )
-    torch._check(
-        B_scale.dtype == torch.uint8,
-        lambda: f"B_scale tensor must be uint8, got {B_scale.dtype}",
-    )
+    _check_scale_dtypes(A_scale, B_scale)
     return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)


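For context, a hedged usage sketch of the call site above (assumes a CUDA build of torchao with the MX kernels registered; the block size of 32 along K, the row-major shapes, and any scale-swizzling requirements of the underlying kernel are assumptions, not confirmed by this diff):

M, K, N = 128, 256, 128
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# One scale per 32-element block along K; 127 is the E8M0 biased exponent for 2**0
A_scale = torch.full((M, K // 32), 127, dtype=torch.uint8, device="cuda")
B_scale = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda")
out = mx_fp8_bf16(A, B, A_scale, B_scale)  # MxN bf16 tensor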
@@ -674,6 +679,7 @@ def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
         MXN bf16 Tensor

     """
+    _check_scale_dtypes(A_scale, B_scale)
     return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)


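The same check now guards mx_fp4_bf16, so on PyTorch builds that expose torch.float8_e8m0fnu callers may pass scales in either representation. A small sketch (the cast to the new dtype is an assumption about recent PyTorch; the uint8 fallback stores the same biased-exponent bit pattern):

if hasattr(torch, "float8_e8m0fnu"):
    # Native e8m0 scales: values are powers of two, e.g. 1.0 == 2**0
    scales = torch.ones(128, 8).to(torch.float8_e8m0fnu)
else:
    # uint8 fallback: store the biased exponent directly (127 -> 2**0)
    scales = torch.full((128, 8), 127, dtype=torch.uint8)
_check_scale_dtypes(scales, scales)  # accepted on either build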