Commit 0aa89a8

Register choose_qparams_affine_float8 as custom op (#2461)
1 parent: dc87bca

File tree (3 files changed, +32 −3 lines):

  test/dtypes/test_affine_quantized_float.py
  test/integration/test_integration.py
  torchao/quantization/quant_primitives.py

test/dtypes/test_affine_quantized_float.py

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ def test_mm_float8dq_per_row(
     )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
+    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 

test/integration/test_integration.py

Lines changed: 28 additions & 0 deletions
@@ -17,6 +17,7 @@
 from parameterized import parameterized
 from torch._dynamo import config
 from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 import torchao
 from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
@@ -37,6 +38,7 @@
 
 # APIs to be deprecated (used for torch 2.2.2 and 2.3)
 from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
     _replace_with_custom_fn_if_matches_filter,
     change_linear_weights_to_int4_woqtensors,
     change_linear_weights_to_int8_dqtensors,
@@ -86,6 +88,7 @@
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
+    is_sm_at_least_89,
     is_sm_at_least_90,
     unwrap_tensor_subclass,
 )
@@ -2077,6 +2080,31 @@ def forward(self, x):
         self.assertTrue(torch.ops.torchao.quantize_affine.default in targets)
         self.assertFalse(torch.ops.aten.narrow.default in targets)
 
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_export_float8(self):
+        class SimpleNetwork(torch.nn.Module):
+            def __init__(self):
+                super(SimpleNetwork, self).__init__()
+                self.linear = torch.nn.Linear(
+                    in_features=32, out_features=16, bias=False
+                )
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = SimpleNetwork().eval().cuda()
+        inp = torch.randn(2, 32).cuda()
+        config = Float8DynamicActivationFloat8WeightConfig()
+        quantize_(model, config)
+
+        ep = torch.export.export(model, (inp,))
+        print(ep)
+        FileCheck().check_count(
+            "torch.ops.torchao.choose_qparams_affine_float8.default", 1, exactly=True
+        ).run(str(ep.graph))
+
 
 class TestUtils(unittest.TestCase):
     @parameterized.expand(
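The FileCheck assertion above can only pass because `choose_qparams_affine_float8` is now a registered custom op: torch.export preserves registered ops as single opaque graph nodes instead of tracing through their Python bodies. A self-contained sketch of that mechanism using torch.library directly, with made-up op and module names (not torchao's actual registration code):

from typing import List

import torch

@torch.library.custom_op("demo::choose_scale_fp8", mutates_args=())
def choose_scale_fp8(tensor: torch.Tensor, block_size: List[int]) -> torch.Tensor:
    # Real implementation: tensorwise float8 scale, as in quant_primitives.py below.
    quant_max = torch.finfo(torch.float8_e4m3fn).max
    if len(block_size) == 0:
        return (tensor.abs().max() / quant_max).to(torch.float32)
    raise NotImplementedError("only tensorwise scaling in this sketch")

@choose_scale_fp8.register_fake
def _(tensor, block_size):
    # Shape-only (meta) implementation so export can trace without real data.
    return torch.empty((), dtype=torch.float32, device=tensor.device)

class M(torch.nn.Module):
    def forward(self, x):
        return x * choose_scale_fp8(x, [])

ep = torch.export.export(M(), (torch.randn(2, 32),))
# ep.graph now contains a single torch.ops.demo.choose_scale_fp8.default call.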

torchao/quantization/quant_primitives.py

Lines changed: 3 additions & 2 deletions
@@ -2178,11 +2178,12 @@ def _dequantize_affine_floatx(
     return tensor
 
 
+@register_custom_op
 def _choose_qparams_affine_float8(
     tensor: torch.Tensor,
+    block_size: List[int],
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
-    block_size: Optional[Tuple[int, ...]] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -2195,7 +2196,7 @@ def _choose_qparams_affine_float8(
     """
     quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
-    if block_size is None:
+    if len(block_size) == 0:
         max_abs = tensor.abs().max()
         scale = max_abs / quant_max
     else:
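Signature notes: `block_size` moves ahead of the dtype arguments and becomes a required `List[int]` (custom-op schema inference supports `int[]` but, presumably, not `Optional[Tuple[int, ...]]`), so the tensorwise check changes from `is None` to `len(block_size) == 0`. A hypothetical call through the registered op, assuming the argument order of the Python signature above:

import torch
import torchao  # noqa: F401  (importing torchao registers its custom ops)

x = torch.randn(16, 32)
# [] replaces the old block_size=None and takes the tensorwise branch above.
scale = torch.ops.torchao.choose_qparams_affine_float8(
    x, [], torch.float8_e4m3fn, torch.float32
)
print(scale)  # expected: a single tensorwise scale factor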
