From d0b71cc3854fd022b06f024faad97bfdba77a1cd Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 23 May 2025 11:32:04 -0700 Subject: [PATCH] Fix Per Row scaling for inference stack-info: PR: https://github.com/pytorch/ao/pull/2253, branch: drisspg/stack/56 --- test/dtypes/test_affine_quantized_float.py | 308 ++++++++++++++++++++- torchao/dtypes/affine_quantized_tensor.py | 7 +- torchao/dtypes/floatx/float8_layout.py | 237 ++++++++++------ torchao/float8/inference.py | 74 ++++- torchao/quantization/quant_api.py | 69 ++--- torchao/quantization/quant_primitives.py | 112 ++++++-- 6 files changed, 628 insertions(+), 179 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 8c36f5ac7a..408e6e6ce0 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -25,6 +25,7 @@ from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl from torchao.float8.float8_utils import compute_error from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, @@ -42,6 +43,9 @@ from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, + choose_qparams_affine_float8, + dequantize_affine_float8, + quantize_affine_float8, ) from torchao.utils import ( is_sm_at_least_89, @@ -297,21 +301,299 @@ def test_fp8_weight_dimension_warning(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - def test_mm_float8dq(self): + @common_utils.parametrize( + "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)] + ) + @common_utils.parametrize( + "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)] + ) # fmt: skip + @common_utils.parametrize("bias", [True, False]) + def test_mm_float8dq_per_row( + self, in_features, out_features, leading_shape, bias: bool + ): + device = "cuda" + dtype = torch.bfloat16 + input_shape = leading_shape + (in_features,) + + ref_linear = ( + torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype) + ) + test_linear = copy.deepcopy(ref_linear) + quantize_( + test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) + + quant_weight = test_linear.weight + + self.assertTrue(hasattr(quant_weight, "original_weight_tensor")) + weight_impl = quant_weight.original_weight_tensor.tensor_impl + + self.assertTrue(hasattr(weight_impl, "float8_data")) + self.assertTrue(hasattr(weight_impl, "scale")) + self.assertFalse(weight_impl.transposed) + + # Verify scale shape for row-wise quantization + expected_scale_shape = (out_features, 1) + actual_scale_shape = weight_impl.scale.shape + self.assertEqual(actual_scale_shape, expected_scale_shape) + + self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features)) + + input_tensor = torch.randn(*input_shape, device=device, dtype=dtype) + + with torch.no_grad(): + ref_output = ref_linear(input_tensor) + quant_output = torch.nn.functional.linear(input_tensor, quant_weight) + + expected_output_shape = input_tensor.shape[:-1] + (out_features,) + self.assertEqual(quant_output.shape, expected_output_shape) + + error = compute_error(ref_output, quant_output) + assert error > 20, f"Quantization error is too high got a SQNR of {error}" + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability 
>= 8.9" + ) + @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) + @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)]) + def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): + """Test dequantize_affine_float8 with various configurations""" + + device = "cuda" + input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) + + # Choose quantization parameters + scale = choose_qparams_affine_float8( + input_tensor, float8_dtype=float8_dtype, block_size=block_size + ) + + # Quantize + quantized = quantize_affine_float8(input_tensor, scale, float8_dtype) + + # Dequantize + dequantized = dequantize_affine_float8(quantized, scale, output_dtype) + + # Verify output properties + self.assertEqual(dequantized.dtype, output_dtype) + self.assertEqual(dequantized.shape, input_tensor.shape) + self.assertEqual(dequantized.device, input_tensor.device) + + # Verify quantization/dequantization roundtrip is reasonable + error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean() + self.assertLess(error, 0.1, "Quantization error too high") + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + def test_dequantize_affine_float8_scale_broadcasting(self): + """Test that scale broadcasting works correctly for block-wise quantization""" + device = "cuda" + # Create input tensor with known block structure + input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32) + block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim + + # Choose quantization parameters + scale = choose_qparams_affine_float8( + input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Verify scale shape + expected_scale_shape = ( + input_tensor.shape[0] // block_size[0], + input_tensor.shape[1] // block_size[1], + ) + self.assertEqual(scale.shape, expected_scale_shape) + + # Quantize + quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) + + # Dequantize + dequantized = dequantize_affine_float8(quantized, scale, torch.float32) + + # Verify shapes match + self.assertEqual(dequantized.shape, input_tensor.shape) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @common_utils.parametrize( + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] + ) + def test_float8_tensor_slicing_basic(self, granularity): + """Test basic slicing operations on Float8 tensors""" device = "cuda" dtype = torch.bfloat16 - weight = torch.randn(512, 1024).to(device).to(dtype) - weight = weight.t() - - l = torch.nn.Linear(512, 1024).to(device).to(dtype) - l.weight = torch.nn.Parameter(weight) - quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) - # weight shape: 1024 x 512 - weight = l.weight - - input = torch.randn(1, 512, device=device, dtype=dtype) - # make sure it runs - torch.nn.functional.linear(input, weight) + + # Create and quantize a model + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + quantize_( + model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + ) + + weight_impl = model.weight.original_weight_tensor.tensor_impl + + # Test dimension 0 slicing (rows) + sliced_0 = 
weight_impl[10:20] + self.assertEqual(sliced_0.shape, (10, 64)) + + # Test dimension 1 slicing (columns) + sliced_1 = weight_impl[:, 20:40] + self.assertEqual(sliced_1.shape, (32, 20)) + + # Test combined slicing + sliced_both = weight_impl[5:15, 10:30] + self.assertEqual(sliced_both.shape, (10, 20)) + + # Verify the sliced tensors are still Float8 tensors + self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl)) + self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl)) + self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl)) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + def test_float8_tensor_slicing_per_tensor(self): + """Test slicing with per-tensor quantization (scale should not change)""" + device = "cuda" + dtype = torch.bfloat16 + + # Create and quantize with per-tensor granularity + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + quantize_( + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) + ) + + original_weight = model.weight + original_impl = original_weight.original_weight_tensor.tensor_impl + original_scale = original_impl.scale + + # Test slicing + sliced_weight = original_weight[10:20, 20:40] + sliced_impl = sliced_weight.original_weight_tensor.tensor_impl + + # For per-tensor quantization, scale should be identical + self.assertTrue(torch.equal(original_scale, sliced_impl.scale)) + self.assertEqual(sliced_impl.scale.numel(), 1) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @unittest.skipIf( + not is_sm_at_least_90(), + "Per-row quantization requires compute capability >= 9.0", + ) + def test_float8_tensor_slicing_per_row(self): + """Test slicing with per-row quantization (scale should be sliced appropriately)""" + device = "cuda" + dtype = torch.bfloat16 + + # Create and quantize with per-row granularity + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + quantize_( + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) + + original_weight = model.weight # Shape: (32, 64) + original_impl = original_weight.original_weight_tensor.tensor_impl + original_scale = original_impl.scale # Shape: (32, 1) + + # Test row slicing (dimension 0) + sliced_rows = original_weight[10:20] # Shape: (10, 64) + sliced_impl = sliced_rows.original_weight_tensor.tensor_impl + + # Scale should be sliced to match the rows + expected_scale_shape = (10, 1) + self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) + + # Verify the scale values are correct (should be subset of original) + self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20])) + + # Test column slicing (dimension 1) - scale should not change for per-row + sliced_cols = original_weight[:, 20:40] # Shape: (32, 20) + sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl + + # Scale shape should remain the same since we're not changing rows + self.assertEqual(sliced_cols_impl.scale.shape, (32, 1)) + self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale)) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + def test_float8_tensor_slicing_edge_cases(self): + """Test edge cases in slicing""" + device = "cuda" + dtype = torch.bfloat16 
+ + # Create and quantize a model + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + quantize_( + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) + ) + + original_weight = model.weight + + # Test empty slice + empty_slice = original_weight[0:0] + self.assertEqual(empty_slice.shape, (0, 64)) + + # Test single element slice + single_row = original_weight[0:1] + self.assertEqual(single_row.shape, (1, 64)) + + # Test out of bounds (should be handled by PyTorch) + large_slice = original_weight[:100] # More than available rows + self.assertEqual(large_slice.shape, (32, 64)) # Should clamp to available + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @common_utils.parametrize( + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] + ) + def test_float8_tensor_slicing_functional_correctness(self, granularity): + """Test that sliced tensors produce correct results in computations""" + device = "cuda" + dtype = torch.bfloat16 + + # Create reference and quantized models with dimensions that are multiples of 16 + ref_model = ( + torch.nn.Linear(64, 48, bias=False).to(device).to(dtype) + ) # 48 is divisible by 16 + quant_model = copy.deepcopy(ref_model) + quantize_( + quant_model, + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + ) + + # Create input with batch size that works well with slicing + input_tensor = torch.randn(8, 64, device=device, dtype=dtype) + + ref_weight_slice = ref_model.weight[0:16, 0:32] + quant_weight_slice = quant_model.weight[0:16, 0:32] + + input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight + + # Compute with sliced weights + with torch.no_grad(): + ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice) + quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice) + + # Verify shapes + expected_shape = (8, 16) # batch_size x out_features_sliced + self.assertEqual(ref_output.shape, expected_shape) + self.assertEqual(quant_output.shape, expected_shape) + + # Verify reasonable quantization error + error = compute_error(ref_output, quant_output) + self.assertGreater(error, 15, f"Quantization SQNR too low: {error}") common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 65649593a3..6cb2e8997e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -462,10 +462,10 @@ def from_hp_to_floatx( if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) - - scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype) + scale = choose_qparams_affine_float8( + input_float, float8_dtype=target_dtype, block_size=block_size + ) data = quantize_affine_float8(input_float, scale, target_dtype) - data, scale, zero_point = _layout.post_process( data, scale, None, block_size ) @@ -503,7 +503,6 @@ def from_hp_to_floatx_static( input_float, scale, target_dtype, - scale_dtype, ) data, scale, zero_point = _layout.post_process( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 5914f00102..799832a5ea 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -4,7 +4,7 @@ # This source code is licensed under the 
BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.utils._python_dispatch import ( @@ -26,6 +26,18 @@ from torchao.utils import _is_float8_type, fill_defaults aten = torch.ops.aten +FLOAT8_IMPL_OPS_TABLE: Dict[Any, Any] = {} + + +def implements(aten_ops: List[Any]): + """Register aten ops to the float8 op table""" + + def decorator(func): + for op in aten_ops: + FLOAT8_IMPL_OPS_TABLE[op] = func + return func + + return decorator def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool: @@ -33,7 +45,6 @@ def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> b transposed_match = (self.transposed == src.transposed) or ( self.transposed is False and src.transposed is None ) - return ( isinstance(self, Float8AQTTensorImpl) and isinstance(src, Float8AQTTensorImpl) @@ -160,90 +171,23 @@ def __tensor_unflatten__( def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - args[0].transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, args[0]) - elif func is aten.copy_.default: - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" - ) - elif func in [aten.select.int, aten.index.Tensor]: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), - ) - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - # TODO: scale replecation should be dependent on block size - if self.scale.ndim == 1: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: aten.slice.Tensor(x, dim, start, end, step) - ), - ) - elif self.scale.ndim == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - Float8AQTTensorImpl( - aten.slice.Tensor(self.float8_data, dim, start, end, step), - self.scale, - None, - self._layout, - ), - ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" - ) - elif dim == 1: - return return_and_correct_aliasing( - func, - args, - kwargs, - Float8AQTTensorImpl( - aten.slice.Tensor( - self.float8_data, dim, start, end, step - ).contiguous(), - self.scale, - None, - self._layout, - ), - ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + def allowed_subclasses(type): + return ( + issubclass(cls, type) + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) + or issubclass( + 
torch._subclasses.functional_tensor.FunctionalTensor, type ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) + if not all(allowed_subclasses(t) for t in types): + return NotImplemented + + if func in FLOAT8_IMPL_OPS_TABLE: + return FLOAT8_IMPL_OPS_TABLE[func](func, types, args, kwargs) + + raise NotImplementedError(f"attempting to run {func}, this is not supported") + __torch_function__ = torch._C._disabled_torch_function_impl def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: @@ -281,6 +225,130 @@ def __repr__(self): ) +########################## +# Regsiter FP8 Ops +########################## + + +@implements([aten.detach.default, aten.alias.default, aten.clone.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(func) + ) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return return_and_correct_aliasing(func, args, kwargs, args[0]) + + +@implements([aten.copy_.default]) +def _(func, types, args, kwargs): + self = args[0] + src = args[1] + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + +@implements([aten.select.int, aten.index.Tensor]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + + +@implements([aten.slice.Tensor]) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + + # Always slice the float8_data + sliced_data = aten.slice.Tensor(self.float8_data, dim, start, end, step) + + if self.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = self.scale + else: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = _slice_scale_for_dimension( + self.scale, self.float8_data.shape, dim, start, end, step + ) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8AQTTensorImpl( + sliced_data, + sliced_scale, + self.transposed, + self._layout, + ), + ) + + +def _slice_scale_for_dimension( + scale: torch.Tensor, + data_shape: List[int], + dim: int, + start: int, + end: int, + step: int, +) -> torch.Tensor: + """ + Slice the scale tensor appropriately based on the data tensor slicing. + + This function calculates how the scale should be sliced when the data tensor + is sliced along a given dimension, taking into account the block structure. 
+ """ + # Unsupported case for now, this would be 1 scale per data element + if scale.shape == data_shape: + return aten.slice.Tensor(scale, dim, start, end, step) + + # Reconstruct block sizes based on data shape and scale shape + block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) + + if dim >= len(block_sizes): + # Slicing beyond the dimensions we care about + return scale + + block_size_for_dim = block_sizes[dim] + + if block_size_for_dim == 1: + # Scale is per-element along this dimension + # Slice away as normal + return aten.slice.Tensor(scale, dim, start, end, step) + else: + # There is blocking in this dimension + # Calculate which scale elements correspond to the sliced data + scale_start = start // block_size_for_dim if start is not None else None + scale_end = ( + (end + block_size_for_dim - 1) // block_size_for_dim + if end is not None + else None + ) + + # Error on Step > 1 + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." + ) + + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) + + ########################## # Float8 Dispatch Kernels ########################## @@ -333,13 +401,12 @@ def _linear_fp8_act_fp8_weight_impl( input_scale = input_tensor.tensor_impl.scale # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) - # Handle rowwise case if _is_rowwise_scaled(weight_tensor): assert _is_rowwise_scaled(input_tensor), ( "Input tensor must be rowwise block size" ) - w_scale = w_scale.unsqueeze(-1).T + w_scale = w_scale.T input_scale = preprocess_scale(input_scale, input_tensor.shape) # Preprocess data diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 7f73a604d7..00c905f3d8 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -7,11 +7,20 @@ Defines an nn module designed to be used during inference """ -from typing import NamedTuple, Optional, Tuple +from typing import NamedTuple, Optional, Tuple, Union import torch from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul +from torchao.quantization.granularity import ( + PerRow, + PerTensor, +) +from torchao.utils import ( + is_MI300, + is_sm_at_least_89, + is_sm_at_least_90, +) Tensor = torch.Tensor @@ -106,3 +115,66 @@ def _is_rowwise_scaled(x) -> bool: x: AffineQuantizedTensor tensor """ return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],) + + +FP8Granularity = Union[PerTensor, PerRow] + + +def _normalize_granularity( + granularity: Optional[ + Union[ + FP8Granularity, + Tuple[FP8Granularity, FP8Granularity], + list[FP8Granularity], + ] + ], +) -> Tuple[FP8Granularity, FP8Granularity]: + processed_granularity = None + if granularity is None: + processed_granularity = (PerTensor(), PerTensor()) + elif isinstance(granularity, (PerTensor, PerRow)): + processed_granularity = (granularity, granularity) + elif isinstance(granularity, (tuple, list)) and len(granularity) == 2: + if not ( + isinstance(granularity[0], (PerTensor, PerRow)) + and isinstance(granularity[1], (PerTensor, PerRow)) + ): + raise ValueError( + f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." + ) + if not isinstance(granularity[0], type(granularity[1])): + raise ValueError( + f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." 
+ ) + processed_granularity = tuple(granularity) + else: + raise ValueError( + f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported." + ) + return processed_granularity + + +def _check_hardware_support( + granularities: Tuple[FP8Granularity, FP8Granularity], +) -> None: + """ + Validate that the hardware supports the requested granularities. + + Args: + granularities: Tuple of (activation_granularity, weight_granularity) + + Raises: + AssertionError: If hardware doesn't support the requested granularity + ValueError: If invalid granularity type is provided + """ + for _granularity in granularities: + if isinstance(_granularity, PerTensor): + assert is_sm_at_least_89() or is_MI300(), ( + "PerTensor quantization only works for CUDA>=8.9 and MI300+" + ) + elif isinstance(_granularity, PerRow): + assert is_sm_at_least_90() or is_MI300(), ( + "PerRow quantization only works for CUDA>=9.0 and MI300+" + ) + else: + raise ValueError(f"Invalid granularity type: {_granularity}") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4229577b95..a9e5725ec0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -54,7 +54,12 @@ from torchao.dtypes.utils import Layout from torchao.float8.config import e4m3_dtype, e5m2_dtype from torchao.float8.float8_linear import Float8Linear -from torchao.float8.inference import Float8MMConfig +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _check_hardware_support, + _normalize_granularity, +) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) @@ -1431,56 +1436,9 @@ def _float8_weight_only_transform( return module -_fp8_granularities = Union[PerTensor, PerRow] - - -# Validate and process granularity input -def _normalize_granularity( - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ], -) -> Tuple[_fp8_granularities, _fp8_granularities]: - processed_granularity = None - if granularity is None: - processed_granularity = (PerTensor(), PerTensor()) - elif isinstance(granularity, (PerTensor, PerRow)): - processed_granularity = (granularity, granularity) - elif isinstance(granularity, tuple) and len(granularity) == 2: - if not ( - isinstance(granularity[0], (PerTensor, PerRow)) - and isinstance(granularity[1], (PerTensor, PerRow)) - ): - raise ValueError( - f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." - ) - if not isinstance(granularity[0], type(granularity[1])): - raise ValueError( - f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." - ) - processed_granularity = granularity - else: - raise ValueError( - f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported." 
- ) - # Validate granularity with supported Hardware - for _granularity in processed_granularity: - if isinstance(_granularity, PerTensor): - assert is_sm_at_least_89() or is_MI300(), ( - "PerTensor quantization only works for CUDA>=8.9 and MI300+" - ) - elif isinstance(_granularity, PerRow): - assert is_sm_at_least_90() or is_MI300(), ( - "PerRow quantization only works for CUDA>=9.0 and MI300+" - ) - else: - raise ValueError(f"Invalid granularity type: {_granularity}") - - return processed_granularity - - def _input_activation_quant_func_fp8( x: torch.Tensor, - activation_granularity: _fp8_granularities, + activation_granularity: FP8Granularity, activation_dtype: torch.dtype, scale: Optional[torch.Tensor] = None, zero_point: Optional[torch.Tensor] = None, @@ -1567,7 +1525,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] ] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True @@ -1576,6 +1534,11 @@ def __post_init__(self): if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) + activation_granularity, weight_granularity = _normalize_granularity( + self.granularity + ) + self.granularity = (activation_granularity, weight_granularity) + # for bc float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig @@ -1587,7 +1550,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): granularity = config.granularity mm_config = config.mm_config - activation_granularity, weight_granularity = _normalize_granularity(granularity) + # Ensure works on device + _check_hardware_support(granularity) + activation_granularity, weight_granularity = granularity if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently @@ -1704,7 +1669,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] ] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1885755608..cee8df21a2 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1970,6 +1970,7 @@ def choose_qparams_affine_float8( tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, scale_dtype: torch.dtype = torch.float32, + block_size: Optional[Tuple[int, ...]] = None, ) -> torch.Tensor: """ Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. @@ -1977,14 +1978,84 @@ def choose_qparams_affine_float8( Args: tensor (torch.Tensor): Input tensor to be quantized. float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32). + block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used. 
""" + quant_max = torch.finfo(float8_dtype).max # only tensorwise scaling is supported for now: - quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max - min_val_neg = torch.min(tensor) - max_val_pos = torch.max(tensor) - max_val_pos = torch.max(-min_val_neg, max_val_pos) - scale = max_val_pos / (float(quant_max - quant_min) / 2) - return scale.to(dtype=scale_dtype) + if block_size is None: + max_abs = tensor.abs().max() + scale = max_abs / quant_max + else: + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, tensor.shape + ) + tensor_reshaped = tensor.view(shape_for_reduction) + max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True) + + scale = max_abs / quant_max + # Reshape scale back to match the expected output shape + # The scale tensor should have the same shape as the input divided by block_size + output_shape = [ + input_size // block_size[i] for i, input_size in enumerate(tensor.shape) + ] + scale = scale.reshape(output_shape) + + if scale_dtype is not torch.float32: + # Shielding for Version > 2.8 + assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported" + scale = torch.exp2(torch.round(torch.log2(scale))) + return scale.to(dtype=torch.float32) + + +def _expand_scale_to_tensor_shape( + scale: torch.Tensor, target_shape: torch.Size +) -> torch.Tensor: + """ + Expand a scale tensor to match the target tensor shape for block-wise quantization. + + Args: + scale (torch.Tensor): Scale tensor with shape corresponding to block structure + target_shape (torch.Size): Target tensor shape to expand to + + Returns: + torch.Tensor: Scale tensor expanded to match target_shape + """ + if scale.shape == target_shape: + # Scale already matches target shape + return scale + + if scale.numel() == 1: + # Scalar scale - can broadcast naturally + return scale + + # Calculate block sizes from shape difference + if len(scale.shape) != len(target_shape): + raise ValueError( + f"Scale tensor has {len(scale.shape)} dimensions but target has {len(target_shape)}" + ) + + block_sizes = tuple( + target_shape[i] // scale.shape[i] for i in range(len(target_shape)) + ) + + # Verify that target_shape is evenly divisible by scale.shape + for i, (target_dim, scale_dim, block_size) in enumerate( + zip(target_shape, scale.shape, block_sizes) + ): + if target_dim != scale_dim * block_size: + raise ValueError( + f"Dimension {i}: target size {target_dim} is not evenly divisible " + f"by scale size {scale_dim} (block size would be {target_dim / scale_dim})" + ) + + # Expand scale using repeat_interleave + expanded_scale = scale + for i, block_size in enumerate(block_sizes): + if block_size > 1: + expanded_scale = expanded_scale.repeat_interleave(block_size, dim=i) + + return expanded_scale def quantize_affine_float8( @@ -1994,16 +2065,13 @@ def quantize_affine_float8( ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. - - Args: - tensor (torch.Tensor): Input tensor to be quantized. - scale (torch.Tensor): Scaling factor for the quantization. - float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). """ - # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically - # upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization. - # In order to match numerics between eager and compile, we upcast manually here. 
- tensor_scaled = tensor.to(torch.float32) / scale + tensor_fp32 = tensor.to(torch.float32) + + # Expand scale to match tensor dimensions for block-wise quantization + scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape) + + tensor_scaled = tensor_fp32 / scale_expanded max_value = torch.finfo(float8_dtype).max tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) fp8_tensor = tensor_clamped.to(float8_dtype) @@ -2017,15 +2085,11 @@ def dequantize_affine_float8( ) -> torch.Tensor: """ Dequantizes the float8 tensor to high precision tensor. - - Args: - tensor (torch.Tensor): Input float8 tensor to be dequantized. - scale (torch.Tensor): Scaling factor for the dequantization. - output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32). """ - # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically - # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization. - # In order to match numerics between eager and compile, we upcast manually here. fp8_tensor = tensor.to(torch.float32) - hp_tensor = fp8_tensor * scale + + # Expand scale to match tensor dimensions for block-wise quantization + scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape) + + hp_tensor = fp8_tensor * scale_expanded return hp_tensor.to(output_dtype)
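
Usage sketch (illustrative, not part of the patch): the snippet below distills the per-row inference path exercised by the new tests — quantize a bf16 linear layer with Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), confirm the weight scale has shape (out_features, 1), slice the weight and check that the scale is sliced with it, and compare the quantized matmul against the reference. Layer sizes and the input shape are arbitrary; it assumes a CUDA device with compute capability >= 9.0 (or MI300), which per-row quantization requires, and pulls PerRow from torchao.quantization.granularity as inference.py does.

import copy

import torch

from torchao.float8.float8_utils import compute_error
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

device, dtype = "cuda", torch.bfloat16
ref = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quant = copy.deepcopy(ref)

# Per-row weight scales: one fp32 scale per output row (requires SM >= 9.0 or MI300).
quantize_(quant, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

impl = quant.weight.original_weight_tensor.tensor_impl
assert impl.scale.shape == (32, 1)         # (out_features, 1)
assert impl.float8_data.shape == (32, 64)  # (out_features, in_features)

# Row slices of the weight carry the matching rows of the scale.
sliced = quant.weight[10:20]
assert sliced.original_weight_tensor.tensor_impl.scale.shape == (10, 1)

# The quantized matmul should stay close to the bf16 reference
# (the new per-row test asserts an SQNR above 20 dB).
x = torch.randn(8, 64, device=device, dtype=dtype)
with torch.no_grad():
    sqnr = compute_error(ref(x), torch.nn.functional.linear(x, quant.weight))
print(f"SQNR: {float(sqnr):.1f} dB")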