|
52 | 52 | make_packed_linear_int8_dynamic_activation_intx_weight_tensor,
|
53 | 53 | )
|
54 | 54 | from torchao.dtypes.utils import Layout
|
| 55 | +from torchao.float8.config import e4m3_dtype, e5m2_dtype |
55 | 56 | from torchao.float8.float8_linear import Float8Linear
|
56 | 57 | from torchao.float8.inference import Float8MMConfig
|
57 | 58 | from torchao.quantization.linear_activation_weight_observed_tensor import (
|
@@ -1396,7 +1397,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
|
1396 | 1397 | The actual matmul will be computed in original precision of the weight tensor.
|
1397 | 1398 | """
|
1398 | 1399 |
|
1399 |
| - weight_dtype: torch.dtype = torch.float8_e4m3fn |
| 1400 | + weight_dtype: torch.dtype = e4m3_dtype |
1400 | 1401 | set_inductor_config: bool = True
|
1401 | 1402 |
|
1402 | 1403 |
|
@@ -1569,8 +1570,8 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
|
1569 | 1570 |
|
1570 | 1571 | """
|
1571 | 1572 |
|
1572 |
| - activation_dtype: torch.dtype = torch.float8_e4m3fn |
1573 |
| - weight_dtype: torch.dtype = torch.float8_e4m3fn |
| 1573 | + activation_dtype: torch.dtype = e4m3_dtype |
| 1574 | + weight_dtype: torch.dtype = e4m3_dtype |
1574 | 1575 | granularity: Optional[
|
1575 | 1576 | Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
|
1576 | 1577 | ] = None
|
@@ -1660,8 +1661,8 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
|
1660 | 1661 | """
|
1661 | 1662 |
|
1662 | 1663 | layout: Layout = CutlassSemiSparseLayout()
|
1663 |
| - activation_dtype: torch.dtype = torch.float8_e5m2 |
1664 |
| - weight_dtype: torch.dtype = torch.float8_e4m3fn |
| 1664 | + activation_dtype: torch.dtype = e5m2_dtype |
| 1665 | + weight_dtype: torch.dtype = e4m3_dtype |
1665 | 1666 |
|
1666 | 1667 |
|
1667 | 1668 | @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
|
@@ -1706,8 +1707,8 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
|
1706 | 1707 | """
|
1707 | 1708 |
|
1708 | 1709 | scale: torch.Tensor
|
1709 |
| - activation_dtype: torch.dtype = torch.float8_e4m3fn |
1710 |
| - weight_dtype: torch.dtype = torch.float8_e4m3fn |
| 1710 | + activation_dtype: torch.dtype = e4m3_dtype |
| 1711 | + weight_dtype: torch.dtype = e4m3_dtype |
1711 | 1712 | granularity: Optional[
|
1712 | 1713 | Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
|
1713 | 1714 | ] = None
|
|
0 commit comments