diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml
index 91083df0bf..a692f85de6 100644
--- a/.github/workflows/float8_test.yml
+++ b/.github/workflows/float8_test.yml
@@ -25,14 +25,14 @@ jobs:
       include:
         - name: SM-89
           runs-on: linux.g6.4xlarge.experimental.nvidia.gpu
-          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu128'
           gpu-arch-type: "cuda"
-          gpu-arch-version: "12.6"
+          gpu-arch-version: "12.8"
         - name: H100
           runs-on: linux.aws.h100
-          torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126'
+          torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128'
           gpu-arch-type: "cuda"
-          gpu-arch-version: "12.4"
+          gpu-arch-version: "12.8"
   permissions:
     id-token: write
     contents: read
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index b63a406715..496764d39f 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -28,7 +28,10 @@
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -675,6 +678,100 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    def _get_weight_scale_and_impl(self, quantized_model, config_type):
+        """Helper to extract weight scale and impl based on config type"""
+        if config_type == "weight_only":
+            weight_impl = quantized_model.weight.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+        else:
+            weight_impl = quantized_model.weight.original_weight_tensor.tensor_impl
+            return weight_impl.scale.float(), weight_impl
+
+    def _verify_power_of_2_scales(self, scale):
+        """Helper to verify scales are powers of 2"""
+        log2_scale = torch.log2(scale)
+        is_power_of_2 = torch.allclose(log2_scale, torch.round(log2_scale), atol=1e-6)
+        self.assertTrue(is_power_of_2, "Scales should be powers of 2")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "config_factory,config_type,min_sm,min_error",
+        [
+            (
+                lambda: Float8WeightOnlyConfig(round_scales_to_power_of_2=True),
+                "weight_only",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                89,
+                12.5,
+            ),
+            (
+                lambda: Float8StaticActivationFloat8WeightConfig(
+                    scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                    granularity=PerTensor(),
+                    round_scales_to_power_of_2=True,
+                ),
+                "static",
+                89,
+                15.0,
+            ),
+            (
+                lambda: Float8DynamicActivationFloat8SemiSparseWeightConfig(
+                    round_scales_to_power_of_2=True
+                ),
+                "dynamic",
+                90,
+                10.0,
+            ),
+        ],
+    )
+    def test_power_of_2_scaling_configs(
+        self, config_factory, config_type, min_sm, min_error
+    ):
+        if min_sm == 90 and not is_sm_at_least_90():
+            self.skipTest("Requires GPU with compute capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+
+        config = config_factory()
+        if isinstance(
+            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
+        ) and not is_sm_version(9, 0):
+            self.skipTest("Float8SemiSparse requires compute capability == 9.0")
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+
+        scale, _ = self._get_weight_scale_and_impl(quantized_model, config_type)
+        self._verify_power_of_2_scales(scale)
+
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+        with torch.no_grad():
+            ref_output = model(input_tensor)
+            quant_output = quantized_model(input_tensor)
+
+        self.assertEqual(ref_output.shape, quant_output.shape)
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, min_error, f"Quantization SQNR too low: {error}")
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
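Note on the `min_error` thresholds above: `compute_error` reports a signal-to-quantization-noise ratio (SQNR) in decibels, so 12.5 dB means the error norm is roughly 4x smaller than the signal norm (10 ** (12.5 / 20) ≈ 4.2). A minimal standalone sketch, assuming the usual 20·log10 norm-ratio definition; `sqnr_db` is an illustrative name, not part of this diff:

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    """SQNR in dB: 20 * log10(||ref|| / ||ref - test||); higher is better."""
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))

ref = torch.randn(8, 32)
noisy = ref + 0.01 * torch.randn_like(ref)
print(sqnr_db(ref, noisy))  # roughly 40 dB for 1% relative noise
```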
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 39f9131a9e..ee67b42c4a 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -457,13 +457,17 @@ def from_hp_to_floatx(
         target_dtype: torch.dtype,
         _layout: Layout,
         scale_dtype: Optional[torch.dtype] = None,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
             scale = _choose_qparams_affine_float8(
-                input_float, float8_dtype=target_dtype, block_size=block_size
+                input_float,
+                float8_dtype=target_dtype,
+                block_size=block_size,
+                round_scales_to_power_of_2=round_scales_to_power_of_2,
             )
             data = _quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
@@ -530,6 +534,7 @@ def from_hp_to_fpx(
         cls,
         input_float: torch.Tensor,
         _layout: Layout,
+        round_scales_to_power_of_2: bool = False,
     ):
         """Create a floatx AffineQuantizedTensor from a high precision tensor.
         Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -545,7 +550,9 @@ def from_hp_to_fpx(
         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
+        scale = _choose_qparams_affine_floatx(
+            input_float, ebits, mbits, round_scales_to_power_of_2
+        )
         floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size
diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
index c7fb1e1a7c..8b68164fd2 100644
--- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py
+++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -22,6 +22,7 @@
     AQTTensorImpl,
     Layout,
 )
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -214,7 +215,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
 
 
 def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
+    tensor: Tensor, ebits: int, mbits: int, round_scales_to_power_of_2: bool = False
 ) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
@@ -230,6 +231,8 @@
     dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
     tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
     tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits)
     return tensor_tc_floatx, scale.to(dtype)
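To make the per-row scale selection in `to_scaled_tc_floatx` concrete: for FP6 (E3M2), `ebits=3` and `mbits=2` give `exp_bias = 2**(3-1) - 1 = 3` and `max_normal = 2**(7 - 3) * (7 / 4) = 28.0`. A hedged standalone sketch of the scale math with the new flag (this approximates only the scale step, not the packed tensor-core path):

```python
import torch

max_normal = 28.0  # largest normal FP6 E3M2 value
tensor = torch.randn(4, 64)
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
# round_scales_to_power_of_2=True snaps each row scale down to a power of 2,
# so the rounded scale never exceeds the exact one:
scale_pow2 = torch.exp2(torch.floor(torch.log2(scale)))
assert (scale_pow2 <= scale).all()
```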
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 625fb29235..9de357439d 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -236,6 +236,7 @@ def pad_tensor_for_matmul(
     return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
 
 
-def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+def _round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
+    """Rounds the scale down to the nearest power of 2."""
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
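For reference, what `_round_scale_down_to_power_of_2` computes on a few concrete values: flooring in log2 space keeps only the exponent of the float32 scale, yielding the largest power of 2 that does not exceed it.

```python
import torch

scales = torch.tensor([0.75, 1.0, 3.2, 100.0])  # float32
rounded = torch.exp2(torch.floor(torch.log2(scales)))
print(rounded)  # tensor([ 0.5000,  1.0000,  2.0000, 64.0000])
```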
""" layout: Layout = CutlassSemiSparseLayout() activation_dtype: torch.dtype = e5m2_dtype weight_dtype: torch.dtype = e4m3_dtype + round_scales_to_power_of_2: bool = False @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig) @@ -1657,11 +1672,16 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform( f"Only CutlassSemiSparseLayout layout is supported. Received {layout}." ) - weight = _float8_cutlass_quant_sparse(weight, weight_dtype) + weight = _float8_cutlass_quant_sparse( + weight, weight_dtype, config.round_scales_to_power_of_2 + ) weight = to_linear_activation_quantized( weight, _float8_cutlass_quant, - quant_kwargs={"target_dtype": activation_dtype}, + quant_kwargs={ + "target_dtype": activation_dtype, + "round_scales_to_power_of_2": config.round_scales_to_power_of_2, + }, ) module.weight = torch.nn.Parameter(weight, requires_grad=False) @@ -1680,6 +1700,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + round_scales_to_power_of_2 (bool): If True, round scaling factors down to the nearest power of 2. """ scale: torch.Tensor @@ -1690,6 +1711,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): ] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True + round_scales_to_power_of_2: bool = False def __post_init__(self): if self.mm_config is None: @@ -1733,12 +1755,14 @@ def _float8_static_activation_float8_weight_transform( target_dtype=weight_dtype, scale_dtype=torch.float32, _layout=Float8Layout(mm_config=mm_config), + round_scales_to_power_of_2=config.round_scales_to_power_of_2, ) input_quant_func = _input_activation_quant_func_fp8 input_quant_kwargs = { "activation_granularity": activation_granularity, "activation_dtype": activation_dtype, + "round_scales_to_power_of_2": config.round_scales_to_power_of_2, } quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata( diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index df136bc06e..6c53683f9a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -10,6 +10,7 @@ import torch +from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, @@ -2119,7 +2120,10 @@ def _choose_qparams_and_quantize_affine_hqq( def _choose_qparams_affine_floatx( - tensor: torch.Tensor, ebits: int, mbits: int + tensor: torch.Tensor, + ebits: int, + mbits: int, + round_scales_to_power_of_2: bool = False, ) -> torch.Tensor: """Choose quantization parameters for floatx quantization. 
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index df136bc06e..6c53683f9a 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -2119,7 +2120,10 @@ def _choose_qparams_and_quantize_affine_hqq(
 
 
 def _choose_qparams_affine_floatx(
-    tensor: torch.Tensor, ebits: int, mbits: int
+    tensor: torch.Tensor,
+    ebits: int,
+    mbits: int,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     """Choose quantization parameters for floatx quantization.
@@ -2143,14 +2147,18 @@
     # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
 
     # workaround: global lookup table
+    dtype = tensor.dtype
     exp_bias = _ONES_TABLE[ebits - 1]
     max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
         _ONES_TABLE[mbits + 1] / (2**mbits)
     )
 
-    dtype = tensor.dtype
     tensor = tensor.float()
     scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
+
     return scale.to(dtype)
@@ -2183,6 +2191,7 @@
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
     block_size: Optional[Tuple[int, ...]] = None,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -2192,6 +2201,7 @@
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
         scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32).
         block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used.
+        round_scales_to_power_of_2 (bool): Whether to round scales down to the nearest power of 2.
     """
     quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
@@ -2213,11 +2223,10 @@
         ]
         scale = scale.reshape(output_shape)
 
-    if scale_dtype is not torch.float32:
-        # Shielding for Version > 2.8
-        assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported"
-        scale = torch.exp2(_Round.apply(torch.log2(scale)))
-    return scale.to(dtype=torch.float32)
+    if round_scales_to_power_of_2:
+        scale = _round_scale_down_to_power_of_2(scale.float())
+
+    return scale.to(dtype=scale_dtype)
 
 
 def _expand_scale_to_tensor_shape(
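Behaviorally, `_choose_qparams_affine_float8` now honors `scale_dtype` on return and gates power-of-2 rounding behind the explicit flag rather than the old `scale_dtype is not torch.float32` branch. A worked tensorwise example, as a standalone sketch mirroring the hunk above (`quant_max = 448.0` for `torch.float8_e4m3fn`):

```python
import torch

hp_tensor = torch.full((4, 4), 10.0)
quant_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# Tensorwise scale: amax / quant_max = 10 / 448 ~= 0.022321
scale = hp_tensor.abs().amax().float() / quant_max
# With round_scales_to_power_of_2=True: 2**floor(log2(0.022321)) = 2**-6 = 0.015625
rounded = torch.exp2(torch.floor(torch.log2(scale)))
print(scale.item(), rounded.item())
```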