Skip to content

Commit 25ddb77

Browse files
authored
Allow for scales to be in new e8m0 dtype (#1742)
stack-info: PR: #1742, branch: drisspg/stack/36
1 parent c72ebc6 commit 25ddb77

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

torchao/ops.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools
2+
13
import torch
24
from torch import Tensor
35

@@ -606,6 +608,27 @@ def _(
606608
return input_scale.new_empty(*input.shape[:-1], weight.shape[0])
607609

608610

611+
@functools.lru_cache()
612+
def _get_dtypes():
613+
"""TODO: when e8m0 is hardened and major release lets remove uint8 support"""
614+
if hasattr(torch, "float8_e8m0fnu"):
615+
return (torch.uint8, torch.float8_e8m0fnu)
616+
return (torch.uint8,)
617+
618+
619+
def _check_scale_dtypes(A_scale, B_scale):
    """Validate that both MX scale tensors have an allowed dtype.

    Args:
        A_scale: scale tensor for operand A.
        B_scale: scale tensor for operand B.

    Raises:
        Whatever ``torch._check`` raises (RuntimeError by default) when a
        scale tensor's dtype is not in the allowed set.
    """
    allowed_dtypes = _get_dtypes()
    # Fix: derive the message from the actual allowed set.  The previous
    # text always named "uint8 or float8_e8m0fnu", which was misleading on
    # torch builds where float8_e8m0fnu does not exist and only uint8 is
    # accepted.  str(torch.uint8) is "torch.uint8"; strip the module prefix
    # to keep the wording of the original message.
    allowed_names = " or ".join(str(d).split(".")[-1] for d in allowed_dtypes)

    torch._check(
        A_scale.dtype in allowed_dtypes,
        lambda: f"A_scale tensor must be {allowed_names}, got {A_scale.dtype}",
    )
    torch._check(
        B_scale.dtype in allowed_dtypes,
        lambda: f"B_scale tensor must be {allowed_names}, got {B_scale.dtype}",
    )
630+
631+
609632
def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
610633
"""Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor.
611634
@@ -625,25 +648,7 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
625648
MXN bf16 Tensor
626649
627650
"""
628-
torch._check(
629-
A.dtype == torch.float8_e4m3fn,
630-
lambda: f"Input tensor A must be float8_e4m3fn, got {A.dtype}",
631-
)
632-
torch._check(
633-
B.dtype == torch.float8_e4m3fn,
634-
lambda: f"Input tensor B must be float8_e4m3fn, got {B.dtype}",
635-
)
636-
637-
# TODO - Once e8m0 dtype is added to core udpate
638-
# Check scale tensors are uint8
639-
torch._check(
640-
A_scale.dtype == torch.uint8,
641-
lambda: f"A_scale tensor must be uint8, got {A_scale.dtype}",
642-
)
643-
torch._check(
644-
B_scale.dtype == torch.uint8,
645-
lambda: f"B_scale tensor must be uint8, got {B_scale.dtype}",
646-
)
651+
_check_scale_dtypes(A_scale, B_scale)
647652
return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
648653

649654

@@ -674,6 +679,7 @@ def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
674679
MXN bf16 Tensor
675680
676681
"""
682+
_check_scale_dtypes(A_scale, B_scale)
677683
return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)
678684

679685

0 commit comments

Comments
 (0)