Commit 212d912

use correct fp8 quantization dtype for AMD GPU

Differential Revision: D75021458
Pull Request resolved: #2225
Parent: 1bbeed1

2 files changed (+18, -7 lines)
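Background: NVIDIA GPUs implement the OCP fp8 formats (torch.float8_e4m3fn / torch.float8_e5m2), while AMD's MI300-series GPUs implement the "fnuz" variants (torch.float8_e4m3fnuz / torch.float8_e5m2fnuz), which use a different exponent bias and reserve a single encoding for NaN, with no negative zero. Hardcoding torch.float8_e4m3fn therefore produces the wrong quantization dtype on MI300. The diff routes the defaults through e4m3_dtype / e5m2_dtype from torchao.float8.config, which are expected to resolve per platform, roughly along these lines (a minimal sketch; the selection logic shown here is an assumption, not the actual contents of that file):

```python
import torch
from torchao.utils import is_MI300, is_ROCM

# Sketch of platform-conditional fp8 dtype selection, assuming the helpers
# is_ROCM/is_MI300 behave as their names suggest:
if is_ROCM() and is_MI300():
    # AMD MI300 uses the "fnuz" fp8 variants
    e4m3_dtype = torch.float8_e4m3fnuz
    e5m2_dtype = torch.float8_e5m2fnuz
else:
    # OCP fp8 formats used by NVIDIA GPUs
    e4m3_dtype = torch.float8_e4m3fn
    e5m2_dtype = torch.float8_e5m2
```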

test/float8/test_base.py (10 additions, 0 deletions)

```diff
@@ -56,6 +56,7 @@
     tensor_to_scale,
 )
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.utils import is_MI300, is_ROCM
 
 random.seed(0)
 torch.manual_seed(0)
@@ -271,6 +272,15 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         sqnr = compute_error(c_ref, c_fp8_compute)
         assert sqnr >= 25.0
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp8_dtype(
+        self,
+    ):
+        if is_ROCM() and is_MI300():
+            assert e4m3_dtype == torch.float8_e4m3fnuz
+        else:
+            assert e4m3_dtype == torch.float8_e4m3fn
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
```
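The new test gates on is_ROCM() and is_MI300() from torchao.utils; their implementations are not part of this diff. A plausible sketch of what such checks could look like (the gcnArchName match in particular is an assumption):

```python
import torch

def is_ROCM() -> bool:
    # PyTorch built against ROCm exposes torch.version.hip
    return torch.version.hip is not None and torch.cuda.is_available()

def is_MI300() -> bool:
    # MI300-series (CDNA3) devices report a gfx94x architecture name
    arch = torch.cuda.get_device_properties(0).gcnArchName if is_ROCM() else ""
    return "gfx94" in arch
```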

torchao/quantization/quant_api.py (8 additions, 7 deletions)

```diff
@@ -52,6 +52,7 @@
     make_packed_linear_int8_dynamic_activation_intx_weight_tensor,
 )
 from torchao.dtypes.utils import Layout
+from torchao.float8.config import e4m3_dtype, e5m2_dtype
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
@@ -1396,7 +1397,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     The actual matmul will be computed in original precision of the weight tensor.
     """
 
-    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
 
 
@@ -1569,8 +1570,8 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
 
     """
 
-    activation_dtype: torch.dtype = torch.float8_e4m3fn
-    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    activation_dtype: torch.dtype = e4m3_dtype
+    weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ] = None
@@ -1660,8 +1661,8 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     """
 
     layout: Layout = CutlassSemiSparseLayout()
-    activation_dtype: torch.dtype = torch.float8_e5m2
-    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    activation_dtype: torch.dtype = e5m2_dtype
+    weight_dtype: torch.dtype = e4m3_dtype
 
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
@@ -1706,8 +1707,8 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
 
     scale: torch.Tensor
-    activation_dtype: torch.dtype = torch.float8_e4m3fn
-    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    activation_dtype: torch.dtype = e4m3_dtype
+    weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ] = None
```
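With the defaults routed through e4m3_dtype / e5m2_dtype, code that constructs these configs without explicit dtypes picks up the platform-appropriate format automatically. An illustrative usage sketch (the model and shapes are made up; assumes a recent torchao where quantize_ accepts a config object):

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# Toy model; any module containing nn.Linear layers works
model = nn.Sequential(nn.Linear(64, 64)).to(device="cuda", dtype=torch.bfloat16)

# No dtypes passed: activation_dtype/weight_dtype default to e4m3_dtype,
# i.e. float8_e4m3fnuz on ROCm MI300 and float8_e4m3fn elsewhere
quantize_(model, Float8DynamicActivationFloat8WeightConfig())
```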
