From d22a1377e29556c57f9e2f70d903fe97f95a82ca Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 18 Mar 2025 23:58:27 +0000 Subject: [PATCH 01/21] initial commit --- .../quantization/lifecycle/forward.py | 46 +++++++++++-------- .../quantization/lifecycle/initialize.py | 30 +++++++++++- .../quantization/quant_args.py | 26 +++++++++++ .../quantization/utils/helpers.py | 27 +++++++---- 4 files changed, 101 insertions(+), 28 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index f4f93f27..0edcc162 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -20,6 +20,7 @@ from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, + QuantizationType, round_to_quantized_type, ) from compressed_tensors.quantization.quant_config import QuantizationStatus @@ -359,18 +360,22 @@ def _quantize( dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - scaled = x / scale - if zero_point is not None: - scaled += zero_point.to(x.dtype) - # clamp first because cast isn't guaranteed to be saturated (ie for fp8) - clamped_value = torch.clamp( - scaled, - q_min, - q_max, - ) - quantized_value = round_to_quantized_type(clamped_value, args) - if dtype is not None: - quantized_value = quantized_value.to(dtype) + if args.num_bits == 4 and args.type == QuantizationType.FLOAT: + # apply fp4 quant + return quantized_value + else: + scaled = x / scale + if zero_point is not None: + scaled += zero_point.to(x.dtype) + # clamp first because cast isn't guaranteed to be saturated (ie for fp8) + clamped_value = torch.clamp( + scaled, + q_min, + q_max, + ) + quantized_value = round_to_quantized_type(clamped_value, args) + if dtype is not None: + quantized_value = quantized_value.to(dtype) return quantized_value @@ -382,13 +387,18 @@ def _dequantize( zero_point: torch.Tensor = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - dequant_value = x_q.to(scale.dtype) - if zero_point is not None: - dequant_value = dequant_value - zero_point.to(scale.dtype) - dequant_value = dequant_value * scale + if args.num_bits == 4 and args.type == QuantizationType.FLOAT: + # apply fp4 deqquant + dequant_value = None + else: + dequant_value = x_q.to(scale.dtype) + + if zero_point is not None: + dequant_value = dequant_value - zero_point.to(scale.dtype) + dequant_value = dequant_value * scale - if dtype is not None: - dequant_value = dequant_value.to(dtype) + if dtype is not None: + dequant_value = dequant_value.to(dtype) return dequant_value diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 6886423a..bc37d84e 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -30,6 +30,8 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from compressed_tensors.utils import ( + FP4_NVFP4_DATA, + FP8_E4M3_DATA, disable_hf_hook, has_offloaded_params, register_offload_parameter, @@ -161,7 +163,33 @@ def _initialize_scale_zero_point( expected_shape = (weight_shape[0], max(num_groups, 1)) scale_dtype = module.weight.dtype - if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]: + + # NVFP4 support; use FP8 scales + # For weight quant, attach global scales 
for NVFP4 + if ( + base_name == "weight" + and quantization_args.num_bits == 4 + and quantization_args.strategy == QuantizationStrategy.FLOAT + ): + scale_dtype = FP8_E4M3_DATA.dtype + # create and attach nvfp4 data + tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) + # Setting data for now - could possibly be handled later in the pipeline + values = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax + # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? + init_global_scale = Parameter( + value, dtype=torch.float32, device=device, requires_grad=False + ) + register_offload_parameter( + module, f"f{base_name}_global_scale", init_global_scale + ) + + if scale_dtype not in [ + torch.float16, + torch.bfloat16, + torch.float32, + FP8_DATA.dtype, + ]: scale_dtype = torch.float16 # initializes empty scale, zero point, and g_idx parameters for the module diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 69c289d2..2da97340 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, Union @@ -24,6 +25,8 @@ __all__ = [ "FP8_DTYPE", + "FP8_E4M3_DATA", + "FP4_NVFP4_DATA", "QuantizationType", "QuantizationStrategy", "QuantizationArgs", @@ -31,8 +34,29 @@ "ActivationOrdering", ] +# TODO: Remove soon in favour of a more descriptive FloatArgs FP8_DTYPE = torch.float8_e4m3fn +FP8_E4M3_DATA = FloatArgs( + exponent=4, + mantissa=3, + bits=8, + max=torch.finfo(torch.float8_e4m3fn).max, + min=torch.finfo(torch.float8_e4m3fn).min, + dtype=torch.float8_e4m3fn, +) +FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) + + +@dataclass +class FloatArgs: + exponent: int + mantissa: int + bits: int + max: float + min: float + dtype: Optional[torch.dtype] = None + class QuantizationType(str, Enum): """ @@ -233,6 +257,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: return model def pytorch_dtype(self) -> torch.dtype: + # TODO: required for the compressor + # Add FP4_nvfp4 type when updating naive_compressor if self.type == QuantizationType.FLOAT: return FP8_DTYPE elif self.type == QuantizationType.INT: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 9f65ee33..e5cd3c75 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -17,7 +17,8 @@ import torch from compressed_tensors.quantization.quant_args import ( - FP8_DTYPE, + FP4_NVFP4_DATA, + FP8_E4M3_DATA, QuantizationArgs, QuantizationStrategy, QuantizationType, @@ -73,6 +74,7 @@ def calculate_qparams( zp_dtype = quantization_args.pytorch_dtype() if quantization_args.symmetric: + # TODO: update for NVFP4 when applying observers max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) @@ -138,14 +140,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: q_max = torch.tensor(bit_range / 2 - 1, device=device) q_min = torch.tensor(-bit_range / 2, device=device) elif quantization_args.type == QuantizationType.FLOAT: - if quantization_args.num_bits != 8: - raise 
ValueError( - "Floating point quantization is only supported for 8 bits," - f"got {quantization_args.num_bits}" - ) - fp_range_info = torch.finfo(FP8_DTYPE) - q_max = torch.tensor(fp_range_info.max, device=device) - q_min = torch.tensor(fp_range_info.min, device=device) + if quantization_args.num_bits == 8: + """ + if quantization_args.num_bits != 8: + raise ValueError( + "Floating point quantization is only supported for 8 bits," + f"got {quantization_args.num_bits}" + ) + """ + q_max = torch.tensor(FP8_E4M3_DATA.max, device=device) + q_min = torch.tensor(FP8_E4M3_DATA.min, device=device) + else: + # nvfp4 ranges + assert quantization_args.num_bits == 4 + q_max = torch.tensor(FP4_NVFP4_DATA.max, device=device) + q_min = torch.tensor(FP4_NVFP4_DATA.min, device=device) else: raise ValueError(f"Invalid quantization type {quantization_args.type}") From be02849b8e864bdaf198923ebd3cf8aa75174f47 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 19 Mar 2025 19:01:34 +0000 Subject: [PATCH 02/21] update --- .../quantization/lifecycle/forward.py | 44 ++++++++----------- .../quantization/lifecycle/initialize.py | 5 ++- .../quantization/quant_args.py | 6 ++- .../quantization/utils/helpers.py | 24 ++++++---- 4 files changed, 42 insertions(+), 37 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 0edcc162..f19928d6 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -360,22 +360,18 @@ def _quantize( dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if args.num_bits == 4 and args.type == QuantizationType.FLOAT: - # apply fp4 quant - return quantized_value - else: - scaled = x / scale - if zero_point is not None: - scaled += zero_point.to(x.dtype) - # clamp first because cast isn't guaranteed to be saturated (ie for fp8) - clamped_value = torch.clamp( - scaled, - q_min, - q_max, - ) - quantized_value = round_to_quantized_type(clamped_value, args) - if dtype is not None: - quantized_value = quantized_value.to(dtype) + scaled = x / scale + if zero_point is not None: + scaled += zero_point.to(x.dtype) + # clamp first because cast isn't guaranteed to be saturated (ie for fp8) + clamped_value = torch.clamp( + scaled, + q_min, + q_max, + ) + quantized_value = round_to_quantized_type(clamped_value, args) + if dtype is not None: + quantized_value = quantized_value.to(dtype) return quantized_value @@ -388,17 +384,13 @@ def _dequantize( dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if args.num_bits == 4 and args.type == QuantizationType.FLOAT: - # apply fp4 deqquant - dequant_value = None - else: - dequant_value = x_q.to(scale.dtype) + dequant_value = x_q.to(scale.dtype) - if zero_point is not None: - dequant_value = dequant_value - zero_point.to(scale.dtype) - dequant_value = dequant_value * scale + if zero_point is not None: + dequant_value = dequant_value - zero_point.to(scale.dtype) + dequant_value = dequant_value * scale - if dtype is not None: - dequant_value = dequant_value.to(dtype) + if dtype is not None: + dequant_value = dequant_value.to(dtype) return dequant_value diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index bc37d84e..ffdc157a 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -25,6 +25,7 @@ ActivationOrdering, 
QuantizationArgs, QuantizationStrategy, + QuantizationType, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -169,7 +170,7 @@ def _initialize_scale_zero_point( if ( base_name == "weight" and quantization_args.num_bits == 4 - and quantization_args.strategy == QuantizationStrategy.FLOAT + and quantization_args.type == QuantizationType.FLOAT ): scale_dtype = FP8_E4M3_DATA.dtype # create and attach nvfp4 data @@ -188,7 +189,7 @@ def _initialize_scale_zero_point( torch.float16, torch.bfloat16, torch.float32, - FP8_DATA.dtype, + FP8_E4M3_DATA.dtype, ]: scale_dtype = torch.float16 diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 2da97340..397cbe1f 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -289,7 +289,11 @@ def round_to_quantized_type( """ original_dtype = tensor.dtype if args.type == QuantizationType.FLOAT: - rounded = tensor.to(FP8_DTYPE) + if args.num_bits == 8: + rounded = tensor.to(FP8_E4M3_DATA.dtype) + elif args.num_bits == 4: + # TODO: cast to whatever value we want fp4 to be post quantization/clamping + rounded = tensor.to() elif args.type == QuantizationType.INT: rounded = torch.round(tensor) else: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index e5cd3c75..3f6b5328 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -76,7 +76,22 @@ def calculate_qparams( if quantization_args.symmetric: # TODO: update for NVFP4 when applying observers max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) - scales = max_val_pos / (float(bit_range) / 2) + + if ( + quantization_args.num_bits == 4 + and quantization_args.type == QuantizationType.FLOAT + ): + # TODO: how do we pass in the global scale? + # An observer is attached per module, we can conditionally pass in + # the global scale + scale = global_scale * (max_val_pos / FP4_NVFP4_DATA.max) + scale = scale.to(FP8_E4M3_DATA.dtype).to(torch.float32) + scale = scale / global_scale + else: + # Divide over bit range over max value? + scales = max_val_pos / (float(bit_range) / 2) + + # needed for fp4? 
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: @@ -141,13 +156,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: q_min = torch.tensor(-bit_range / 2, device=device) elif quantization_args.type == QuantizationType.FLOAT: if quantization_args.num_bits == 8: - """ - if quantization_args.num_bits != 8: - raise ValueError( - "Floating point quantization is only supported for 8 bits," - f"got {quantization_args.num_bits}" - ) - """ q_max = torch.tensor(FP8_E4M3_DATA.max, device=device) q_min = torch.tensor(FP8_E4M3_DATA.min, device=device) else: From 974953ca96f226066d6d0dd08492d3d1ad35d34d Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 19 Mar 2025 19:08:06 +0000 Subject: [PATCH 03/21] update --- src/compressed_tensors/quantization/quant_args.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 397cbe1f..ccd5b44c 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -260,7 +260,12 @@ def pytorch_dtype(self) -> torch.dtype: # TODO: required for the compressor # Add FP4_nvfp4 type when updating naive_compressor if self.type == QuantizationType.FLOAT: - return FP8_DTYPE + if self.num_bits == 8: + return FP8_E4M3_DATA.dtype + else: + assert self.num_bits == 4 + # TODO: will return None for now until updated in FloatArgs + return FP4_NVFP4_DATA.dtype elif self.type == QuantizationType.INT: if self.num_bits <= 8: return torch.int8 @@ -291,9 +296,10 @@ def round_to_quantized_type( if args.type == QuantizationType.FLOAT: if args.num_bits == 8: rounded = tensor.to(FP8_E4M3_DATA.dtype) - elif args.num_bits == 4: + else: + assert args.num_bits == 4 # TODO: cast to whatever value we want fp4 to be post quantization/clamping - rounded = tensor.to() + rounded = tensor.to(FP4_NVFP4_DATA.dtype) elif args.type == QuantizationType.INT: rounded = torch.round(tensor) else: From 79437ef1e32e6580eed9774440b3a03434d9bd16 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 19 Mar 2025 19:25:03 +0000 Subject: [PATCH 04/21] update --- .../model_compressors/model_compressor.py | 6 +++--- .../quantization/lifecycle/initialize.py | 11 +++++----- .../quantization/quant_args.py | 21 ++++++++++--------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 618b49ee..4139ef93 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -374,9 +374,9 @@ def compress( compressed_state_dict = state_dict - quantized_modules_to_args: Dict[ - str, QuantizationArgs - ] = map_modules_to_quant_args(model) + quantized_modules_to_args: Dict[str, QuantizationArgs] = ( + map_modules_to_quant_args(model) + ) if self.quantization_compressor is not None: compressed_state_dict = self.quantization_compressor.compress( diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index ffdc157a..ffd734eb 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -22,6 +22,8 @@ 
wrap_module_forward_quantized, ) from compressed_tensors.quantization.quant_args import ( + FP4_NVFP4_DATA, + FP8_E4M3_DATA, ActivationOrdering, QuantizationArgs, QuantizationStrategy, @@ -31,8 +33,6 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from compressed_tensors.utils import ( - FP4_NVFP4_DATA, - FP8_E4M3_DATA, disable_hf_hook, has_offloaded_params, register_offload_parameter, @@ -176,11 +176,10 @@ def _initialize_scale_zero_point( # create and attach nvfp4 data tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) # Setting data for now - could possibly be handled later in the pipeline - values = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax + value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax + value = value.to(torch.float32).to(device) # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? - init_global_scale = Parameter( - value, dtype=torch.float32, device=device, requires_grad=False - ) + init_global_scale = Parameter(value, requires_grad=False) register_offload_parameter( module, f"f{base_name}_global_scale", init_global_scale ) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index ccd5b44c..100544ff 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -34,6 +34,17 @@ "ActivationOrdering", ] + +@dataclass +class FloatArgs: + exponent: int + mantissa: int + bits: int + max: float + min: float + dtype: Optional[torch.dtype] = None + + # TODO: Remove soon in favour of a more descriptive FloatArgs FP8_DTYPE = torch.float8_e4m3fn @@ -48,16 +59,6 @@ FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) -@dataclass -class FloatArgs: - exponent: int - mantissa: int - bits: int - max: float - min: float - dtype: Optional[torch.dtype] = None - - class QuantizationType(str, Enum): """ Enum storing quantization type options From 36204f0f3e7ace8d7e01259b5bf51c979203e546 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 31 Mar 2025 21:28:26 +0000 Subject: [PATCH 05/21] update quant/dequant steps; update scale calculation step --- .../quantization/lifecycle/forward.py | 55 ++++++++++++++----- .../quantization/lifecycle/initialize.py | 3 + .../quantization/quant_args.py | 3 + .../quantization/utils/helpers.py | 15 +++-- 4 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index f19928d6..1e69ee1f 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -50,6 +50,7 @@ def quantize( args: QuantizationArgs, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Quantize the input tensor x using the QuantizationStrategy specified in args. @@ -76,6 +77,7 @@ def quantize( do_quantize=True, do_dequantize=False, g_idx=g_idx, + global_scale=global_scale, ) @@ -87,6 +89,7 @@ def dequantize( args: Optional[QuantizationArgs] = None, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Dequantize a quantized input tensor x_q based on the strategy specified in args. 
If @@ -129,6 +132,7 @@ def dequantize( do_dequantize=True, dtype=dtype, g_idx=g_idx, + global_scale=global_scale, ) @@ -139,6 +143,7 @@ def fake_quantize( zero_point: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None, + global_scale: Optiona[torch.Tensor] = None, ) -> torch.Tensor: """ Fake quantize the input tensor x by quantizing then dequantizing with @@ -162,6 +167,7 @@ def fake_quantize( do_quantize=True, do_dequantize=True, g_idx=g_idx, + global_scale=global_scale, ) @@ -175,6 +181,7 @@ def _process_quantization( dtype: Optional[torch.dtype] = None, do_quantize: bool = True, do_dequantize: bool = True, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: q_min, q_max = calculate_range(args, x.device) group_size = args.group_size @@ -222,18 +229,21 @@ def _process_quantization( end = start + group_count if do_quantize: output[:, start:end] = _quantize( - x[:, start:end], - sc, - zp, - q_min, - q_max, - args, + x=x[:, start:end], + scale=sc, + zero_point=zp, + q_min=q_min, + q_max=q_max, + args=args, dtype=dtype, + global_scale=global_scale, ) if do_dequantize: input = output[:, start:end] if do_quantize else x[:, start:end] - output[:, start:end] = _dequantize(input, sc, zp) + output[:, start:end] = _dequantize( + x=input, scale=sc, zero_point=zp, global_scale=global_scale + ) if not is_column_order: output = safe_permute(output, torch.argsort(perm), dim=1) @@ -241,16 +251,22 @@ def _process_quantization( else: # covers channel, token and tensor strategies if do_quantize: output = _quantize( - x, - scale, - zero_point, - q_min, - q_max, - args, + x=x, + scale=scale, + zero_point=zero_point, + q_min=q_min, + q_max=q_max, + args=args, dtype=dtype, + global_scale=global_scale, ) if do_dequantize: - output = _dequantize(output if do_quantize else x, scale, zero_point) + output = _dequantize( + output if do_quantize else x, + scale=scale, + zero_point=zero_point, + global_scale=global_scale, + ) return output @@ -331,6 +347,7 @@ def forward_quantize( return value g_idx = getattr(module, "weight_g_idx", None) + global_scale = getattr(module, f"{base_name}_global_scale", None) if args.dynamic: # dynamic quantization - determine the scale/zp on the fly @@ -346,6 +363,7 @@ def forward_quantize( zero_point=zero_point, args=args, g_idx=g_idx, + global_scale=global_scale, ) @@ -358,11 +376,16 @@ def _quantize( q_max: torch.Tensor, args: QuantizationArgs, dtype: Optional[torch.dtype] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if global_scale: + scale = scale.to(global_scale.dtype) * global_scale + scaled = x / scale if zero_point is not None: scaled += zero_point.to(x.dtype) + # clamp first because cast isn't guaranteed to be saturated (ie for fp8) clamped_value = torch.clamp( scaled, @@ -382,8 +405,12 @@ def _dequantize( scale: torch.Tensor, zero_point: torch.Tensor = None, dtype: Optional[torch.dtype] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if global_scale: + scale = scale.to(global_scale.dtype) * global_scale + dequant_value = x_q.to(scale.dtype) if zero_point is not None: diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index ffd734eb..dee7a52c 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -167,6 +167,8 @@ def _initialize_scale_zero_point( # NVFP4 support; use FP8 scales # For weight quant, attach global 
scales for NVFP4 + # TODO: How do we know if we need a global scale? + # TODO: NVFP4 Scheme if ( base_name == "weight" and quantization_args.num_bits == 4 @@ -177,6 +179,7 @@ def _initialize_scale_zero_point( tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) # Setting data for now - could possibly be handled later in the pipeline value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax + # use the weight dtype (bfloat) maybe use float32 to start? value = value.to(torch.float32).to(device) # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? init_global_scale = Parameter(value, requires_grad=False) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 100544ff..d9f3ca26 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -56,6 +56,7 @@ class FloatArgs: min=torch.finfo(torch.float8_e4m3fn).min, dtype=torch.float8_e4m3fn, ) +# Don't call NVFP4; should be based on exponent and mantissa FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) @@ -265,6 +266,7 @@ def pytorch_dtype(self) -> torch.dtype: return FP8_E4M3_DATA.dtype else: assert self.num_bits == 4 + # TODO: Use the look-up? # TODO: will return None for now until updated in FloatArgs return FP4_NVFP4_DATA.dtype elif self.type == QuantizationType.INT: @@ -299,6 +301,7 @@ def round_to_quantized_type( rounded = tensor.to(FP8_E4M3_DATA.dtype) else: assert args.num_bits == 4 + # TODO: Use the FP4_NVFP4_DATA class to use a look-up table # TODO: cast to whatever value we want fp4 to be post quantization/clamping rounded = tensor.to(FP4_NVFP4_DATA.dtype) elif args.type == QuantizationType.INT: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 3f6b5328..aee87793 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -55,7 +55,10 @@ def calculate_qparams( - min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs + min_vals: Tensor, + max_vals: Tensor, + quantization_args: QuantizationArgs, + global_scale: Optional[Tensor] = None, ) -> Tuple[FloatTensor, IntTensor]: """ :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) @@ -81,17 +84,19 @@ def calculate_qparams( quantization_args.num_bits == 4 and quantization_args.type == QuantizationType.FLOAT ): + assert global_scale is not None # TODO: how do we pass in the global scale? # An observer is attached per module, we can conditionally pass in - # the global scale - scale = global_scale * (max_val_pos / FP4_NVFP4_DATA.max) - scale = scale.to(FP8_E4M3_DATA.dtype).to(torch.float32) + # the global scale --> TODO: check for presence of the global when updating the scale + # TODO: maybe remove FP8 scale cast + scale = max_val_pos / FP4_NVFP4_DATA.max scale = scale / global_scale + scale = scale.to(FP8_E4M3_DATA.dtype) # .to(torch.float32) + else: # Divide over bit range over max value? scales = max_val_pos / (float(bit_range) / 2) - # needed for fp4? 
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: From d49830d43f6049fb0bc336b82e9963111637ffe0 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 00:58:32 +0000 Subject: [PATCH 06/21] update NVFP4 data type; add scheme --- .../quantization/lifecycle/initialize.py | 10 ++++------ src/compressed_tensors/quantization/quant_args.py | 5 ++--- src/compressed_tensors/quantization/quant_scheme.py | 12 ++++++++++++ src/compressed_tensors/quantization/utils/helpers.py | 12 ++++-------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index dee7a52c..2361e031 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -22,7 +22,7 @@ wrap_module_forward_quantized, ) from compressed_tensors.quantization.quant_args import ( - FP4_NVFP4_DATA, + FP4_E2M1_DATA, FP8_E4M3_DATA, ActivationOrdering, QuantizationArgs, @@ -167,19 +167,17 @@ def _initialize_scale_zero_point( # NVFP4 support; use FP8 scales # For weight quant, attach global scales for NVFP4 - # TODO: How do we know if we need a global scale? # TODO: NVFP4 Scheme if ( - base_name == "weight" - and quantization_args.num_bits == 4 + quantization_args.num_bits == 4 and quantization_args.type == QuantizationType.FLOAT ): scale_dtype = FP8_E4M3_DATA.dtype # create and attach nvfp4 data tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) # Setting data for now - could possibly be handled later in the pipeline - value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax - # use the weight dtype (bfloat) maybe use float32 to start? + value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax + # TODO: use model.weight.dtype value = value.to(torch.float32).to(device) # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? 
init_global_scale = Parameter(value, requires_grad=False) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9f3ca26..a1cd1845 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -26,7 +26,7 @@ __all__ = [ "FP8_DTYPE", "FP8_E4M3_DATA", - "FP4_NVFP4_DATA", + "FP4_E2M1_DATA", "QuantizationType", "QuantizationStrategy", "QuantizationArgs", @@ -56,8 +56,7 @@ class FloatArgs: min=torch.finfo(torch.float8_e4m3fn).min, dtype=torch.float8_e4m3fn, ) -# Don't call NVFP4; should be based on exponent and mantissa -FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) +FP4_E2M1_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) class QuantizationType(str, Enum): diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 36b88604..6ea3942d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool: UNQUANTIZED = dict() +NVFP4 = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.GROUP, + symmetric=True, + dynamic=False, + group_size=16, + ) +) + # 8 bit integer weights and 8 bit activations quantization INT8_W8A8 = dict( weights=QuantizationArgs( @@ -212,4 +223,5 @@ def is_preset_scheme(name: str) -> bool: # Float weight and activation schemes "FP8": FP8, "FP8_DYNAMIC": FP8_DYNAMIC, + "NVFP4": NVFP4, } diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index aee87793..c1ef90e1 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -17,7 +17,7 @@ import torch from compressed_tensors.quantization.quant_args import ( - FP4_NVFP4_DATA, + FP4_E2M1_DATA, FP8_E4M3_DATA, QuantizationArgs, QuantizationStrategy, @@ -85,11 +85,7 @@ def calculate_qparams( and quantization_args.type == QuantizationType.FLOAT ): assert global_scale is not None - # TODO: how do we pass in the global scale? 
- # An observer is attached per module, we can conditionally pass in - # the global scale --> TODO: check for presence of the global when updating the scale - # TODO: maybe remove FP8 scale cast - scale = max_val_pos / FP4_NVFP4_DATA.max + scale = max_val_pos / FP4_E2M1_DATA.max # Not needed scale = scale / global_scale scale = scale.to(FP8_E4M3_DATA.dtype) # .to(torch.float32) @@ -166,8 +162,8 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: else: # nvfp4 ranges assert quantization_args.num_bits == 4 - q_max = torch.tensor(FP4_NVFP4_DATA.max, device=device) - q_min = torch.tensor(FP4_NVFP4_DATA.min, device=device) + q_max = torch.tensor(FP4_E2M1_DATA.max, device=device) + q_min = torch.tensor(FP4_E2M1_DATA.min, device=device) else: raise ValueError(f"Invalid quantization type {quantization_args.type}") From 925482110592f32027b0c33575b4f4f8248cd929 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 01:18:02 +0000 Subject: [PATCH 07/21] update datatype/look-up table --- .../quantization/lifecycle/forward.py | 2 +- .../quantization/quant_args.py | 27 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 1e69ee1f..dfbbf563 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -143,7 +143,7 @@ def fake_quantize( zero_point: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None, - global_scale: Optiona[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Fake quantize the input tensor x by quantizing then dequantizing with diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index a1cd1845..12d7f72f 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -45,6 +45,22 @@ class FloatArgs: dtype: Optional[torch.dtype] = None +@dataclass +class FloatArgsFP4E2M1(FloatArgs): + def cast_to_fp4(self, x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + # TODO: Remove soon in favour of a more descriptive FloatArgs FP8_DTYPE = torch.float8_e4m3fn @@ -56,7 +72,8 @@ class FloatArgs: min=torch.finfo(torch.float8_e4m3fn).min, dtype=torch.float8_e4m3fn, ) -FP4_E2M1_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) + +FP4_E2M1_DATA = FloatArgsFP4E2M1(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) class QuantizationType(str, Enum): @@ -265,9 +282,7 @@ def pytorch_dtype(self) -> torch.dtype: return FP8_E4M3_DATA.dtype else: assert self.num_bits == 4 - # TODO: Use the look-up? 
- # TODO: will return None for now until updated in FloatArgs - return FP4_NVFP4_DATA.dtype + raise NotImplementedError("Not supported for FP4") elif self.type == QuantizationType.INT: if self.num_bits <= 8: return torch.int8 @@ -300,9 +315,7 @@ def round_to_quantized_type( rounded = tensor.to(FP8_E4M3_DATA.dtype) else: assert args.num_bits == 4 - # TODO: Use the FP4_NVFP4_DATA class to use a look-up table - # TODO: cast to whatever value we want fp4 to be post quantization/clamping - rounded = tensor.to(FP4_NVFP4_DATA.dtype) + rounded = FP4_E2M1_DATA.cast_to_fp4(tensor) elif args.type == QuantizationType.INT: rounded = torch.round(tensor) else: From 172750861a798697be8e9396b5f6bd60e1bce177 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 01:22:30 +0000 Subject: [PATCH 08/21] fix param name --- src/compressed_tensors/quantization/lifecycle/forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index dfbbf563..c2fb08eb 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -242,7 +242,7 @@ def _process_quantization( if do_dequantize: input = output[:, start:end] if do_quantize else x[:, start:end] output[:, start:end] = _dequantize( - x=input, scale=sc, zero_point=zp, global_scale=global_scale + x_q=input, scale=sc, zero_point=zp, global_scale=global_scale ) if not is_column_order: From eec7bd3c45eb630a6763833d9381bafdf22127d2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 17:48:40 +0000 Subject: [PATCH 09/21] update --- .../quantization/lifecycle/initialize.py | 13 ++++++++++--- .../quantization/utils/helpers.py | 15 ++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2361e031..c357e5f2 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -177,12 +177,12 @@ def _initialize_scale_zero_point( tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) # Setting data for now - could possibly be handled later in the pipeline value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax - # TODO: use model.weight.dtype + # TODO: use model.weight.dtype after checking value = value.to(torch.float32).to(device) # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? 
init_global_scale = Parameter(value, requires_grad=False) register_offload_parameter( - module, f"f{base_name}_global_scale", init_global_scale + module, f"{base_name}_global_scale", init_global_scale ) if scale_dtype not in [ @@ -201,7 +201,14 @@ def _initialize_scale_zero_point( register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: - zp_dtype = quantization_args.pytorch_dtype() + if ( + quantization_args.num_bits == 4 + and quantization_args.type == QuantizationType.FLOAT + ): + zp_dtype = FP8_E4M3_DATA.dtype + else: + zp_dtype = quantization_args.pytorch_dtype() + init_zero_point = Parameter( torch.zeros(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index c1ef90e1..15a80157 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -74,7 +74,9 @@ def calculate_qparams( bit_min, bit_max = calculate_range(quantization_args, device) bit_range = bit_max - bit_min - zp_dtype = quantization_args.pytorch_dtype() + # TODO: update + # zp_dtype = quantization_args.pytorch_dtype() + zp_dtype = FP8_E4M3_DATA.dtype if quantization_args.symmetric: # TODO: update for NVFP4 when applying observers @@ -85,15 +87,18 @@ def calculate_qparams( and quantization_args.type == QuantizationType.FLOAT ): assert global_scale is not None - scale = max_val_pos / FP4_E2M1_DATA.max # Not needed - scale = scale / global_scale - scale = scale.to(FP8_E4M3_DATA.dtype) # .to(torch.float32) + breakpoint() + scales = max_val_pos / FP4_E2M1_DATA.max # Not needed + scales = scales / global_scale + scales = scales.to(FP8_E4M3_DATA.dtype) # .to(torch.float32) else: # Divide over bit range over max value? 
scales = max_val_pos / (float(bit_range) / 2) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + # TODO: clamp not implemented for FP8 ' + breakpoint() + # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: scales = (max_vals - min_vals) / float(bit_range) From b11b96a0d42c018aac2785d6e12ddaf653d4e626 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 18:48:27 +0000 Subject: [PATCH 10/21] swap operations --- .../quantization/lifecycle/forward.py | 4 ++-- src/compressed_tensors/quantization/utils/helpers.py | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index c2fb08eb..9fc3a68d 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -380,7 +380,7 @@ def _quantize( ) -> torch.Tensor: if global_scale: - scale = scale.to(global_scale.dtype) * global_scale + scale = scale.to(global_scale.dtype) / global_scale scaled = x / scale if zero_point is not None: @@ -409,7 +409,7 @@ def _dequantize( ) -> torch.Tensor: if global_scale: - scale = scale.to(global_scale.dtype) * global_scale + scale = scale.to(global_scale.dtype) / global_scale dequant_value = x_q.to(scale.dtype) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 15a80157..8fa8a259 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -87,17 +87,14 @@ def calculate_qparams( and quantization_args.type == QuantizationType.FLOAT ): assert global_scale is not None - breakpoint() - scales = max_val_pos / FP4_E2M1_DATA.max # Not needed - scales = scales / global_scale - scales = scales.to(FP8_E4M3_DATA.dtype) # .to(torch.float32) - + scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) # Not needed + # scales = scales / global_scale + scales = scales.to(FP8_E4M3_DATA.dtype) else: # Divide over bit range over max value? - scales = max_val_pos / (float(bit_range) / 2) + scales = max_val_pos / (float(bit_radnge) / 2) # TODO: clamp not implemented for FP8 ' - breakpoint() # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: From e8c6c8fb2b7b383ef4dbc45b8eafe46298347572 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 1 Apr 2025 19:23:08 +0000 Subject: [PATCH 11/21] fix typo --- src/compressed_tensors/quantization/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 8fa8a259..097a5e5f 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -92,7 +92,7 @@ def calculate_qparams( scales = scales.to(FP8_E4M3_DATA.dtype) else: # Divide over bit range over max value? 
- scales = max_val_pos / (float(bit_radnge) / 2) + scales = max_val_pos / (float(bit_range) / 2) # TODO: clamp not implemented for FP8 ' # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) From be30822f19fa54e05769ea6554a672e3a4321093 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 2 Apr 2025 20:01:16 +0000 Subject: [PATCH 12/21] fix condition --- src/compressed_tensors/quantization/utils/helpers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 097a5e5f..dd8a1a8f 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -88,14 +88,15 @@ def calculate_qparams( ): assert global_scale is not None scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) # Not needed - # scales = scales / global_scale scales = scales.to(FP8_E4M3_DATA.dtype) else: # Divide over bit range over max value? scales = max_val_pos / (float(bit_range) / 2) - # TODO: clamp not implemented for FP8 ' - # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + # TODO: clamp not implemented for FP8 - we shouldn't need to clamp this anyway as we're + # casting to FP8 on line 92? + if scales.dtype != FP8_E4M3_DATA.dtype: + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: scales = (max_vals - min_vals) / float(bit_range) From 682c1102ad18a10efe55a2eccdee8de9e8e93282 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 24 Apr 2025 19:58:48 +0000 Subject: [PATCH 13/21] fix condition --- .../model_compressors/model_compressor.py | 6 +++--- .../quantization/lifecycle/initialize.py | 5 ++--- src/compressed_tensors/quantization/utils/helpers.py | 12 ++++++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index e5ca9916..7a7a5e88 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -377,9 +377,9 @@ def compress( compressed_state_dict = state_dict - quantized_modules_to_args: Dict[str, QuantizationArgs] = ( - map_modules_to_quant_args(model) - ) + quantized_modules_to_args: Dict[ + str, QuantizationArgs + ] = map_modules_to_quant_args(model) if self.quantization_compressor is not None: compressed_state_dict = self.quantization_compressor.compress( diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 3ade60e5..2dc15304 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -193,17 +193,16 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # TODO: consider erroring out in the future as if the dtype if not one fo these, # there is likely bug - + if scale_dtype not in [ torch.float16, torch.bfloat16, torch.float32, FP8_E4M3_DATA.dtype, ]: - scale_dtype = torch.float16 + scale_dtype = torch.float16 # initializes empty scale, zero point, and g_idx parameters for the module init_scale = Parameter( diff --git a/src/compressed_tensors/quantization/utils/helpers.py 
b/src/compressed_tensors/quantization/utils/helpers.py index c6582a01..64fe7186 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -82,7 +82,6 @@ def calculate_qparams( zp_dtype = FP8_E4M3_DATA.dtype if quantization_args.symmetric: - # TODO: update for NVFP4 when applying observers max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) if ( @@ -96,10 +95,15 @@ def calculate_qparams( # Divide over bit range over max value? scales = max_val_pos / (float(bit_range) / 2) - # TODO: clamp not implemented for FP8 - we shouldn't need to clamp this anyway as we're - # casting to FP8 on line 92? - if scales.dtype != FP8_E4M3_DATA.dtype: + if scales.dtype == FP8_E4M3_DATA.dtype: + # use the next largest fp8 value from 0 + # Optionally, we swap to use the reciporcal + scales = torch.where( + scales == 0, torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype), scales + ) + else: scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: scales = (max_vals - min_vals) / float(bit_range) From 35d98d55b6fc3171a8ff9edb8b40e1b544402833 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 24 Apr 2025 22:01:54 +0000 Subject: [PATCH 14/21] per tensor input scales are never good??? --- .../quantization/lifecycle/initialize.py | 29 ++++++++++--------- .../quantization/quant_scheme.py | 10 ++++++- .../quantization/utils/helpers.py | 2 +- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2dc15304..9f305b34 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -175,23 +175,26 @@ def _initialize_scale_zero_point( # NVFP4 support; use FP8 scales # For weight quant, attach global scales for NVFP4 - # TODO: NVFP4 Scheme if ( quantization_args.num_bits == 4 and quantization_args.type == QuantizationType.FLOAT ): - scale_dtype = FP8_E4M3_DATA.dtype - # create and attach nvfp4 data - tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) - # Setting data for now - could possibly be handled later in the pipeline - value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax - # TODO: use model.weight.dtype after checking - value = value.to(torch.float32).to(device) - # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? - init_global_scale = Parameter(value, requires_grad=False) - register_offload_parameter( - module, f"{base_name}_global_scale", init_global_scale - ) + if base_name == "weight": + scale_dtype = FP8_E4M3_DATA.dtype + # create and attach nvfp4 data + tensor_amax = torch.abs(module.weight.data).max().to(torch.float32) + # Setting data for now - could possibly be handled later in the pipeline + value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax + # TODO: use model.weight.dtype after checking + value = value.to(torch.float32).to(device) + # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32? 
+ init_global_scale = Parameter(value, requires_grad=False) + register_offload_parameter( + module, f"{base_name}_global_scale", init_global_scale + ) + else: + # input scales should be float32 + scale_dtype = torch.float32 # TODO: consider erroring out in the future as if the dtype if not one fo these, # there is likely bug diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 3d5c9a76..28a19e12 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -108,7 +108,15 @@ def is_preset_scheme(name: str) -> bool: symmetric=True, dynamic=False, group_size=16, - ) + ), + input_activations=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + observer=None, + ), ) # 8 bit integer weights and 8 bit activations quantization diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 64fe7186..c9429231 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -87,8 +87,8 @@ def calculate_qparams( if ( quantization_args.num_bits == 4 and quantization_args.type == QuantizationType.FLOAT + and global_scale is not None ): - assert global_scale is not None scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) # Not needed scales = scales.to(FP8_E4M3_DATA.dtype) else: From 107bd938872133ab22d91ed843b76e6a9a33ce86 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 28 Apr 2025 20:58:39 +0000 Subject: [PATCH 15/21] remove scheme --- src/compressed_tensors/quantization/quant_scheme.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 28a19e12..3d5c9a76 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -108,15 +108,7 @@ def is_preset_scheme(name: str) -> bool: symmetric=True, dynamic=False, group_size=16, - ), - input_activations=QuantizationArgs( - num_bits=4, - type=QuantizationType.FLOAT, - strategy=QuantizationStrategy.TENSOR, - symmetric=True, - dynamic=False, - observer=None, - ), + ) ) # 8 bit integer weights and 8 bit activations quantization From daca970086c31c948aecce1644f446279437842e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 5 May 2025 10:47:21 -0400 Subject: [PATCH 16/21] [WIP][NVFP4] Add compression/decompression code (#291) * add nvfp4 packing * add model_opt compressor * update script * update compress/decompress methods * update * update * update --- .../compressors/compress_to_fp4.py | 126 ++++++++++++ .../quantized_compressors/__init__.py | 1 + .../compressors/quantized_compressors/base.py | 4 + .../modelopt_quantized.py | 180 ++++++++++++++++++ src/compressed_tensors/config/base.py | 1 + 5 files changed, 312 insertions(+) create mode 100644 src/compressed_tensors/compressors/compress_to_fp4.py create mode 100644 src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py diff --git a/src/compressed_tensors/compressors/compress_to_fp4.py b/src/compressed_tensors/compressors/compress_to_fp4.py new file mode 100644 index 00000000..a523f9c0 --- /dev/null +++ b/src/compressed_tensors/compressors/compress_to_fp4.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import torch + + +FLOAT_TO_E2M1 = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, +] +conversion_dict = {} + +# Dictionary between fp4 value and index +for i in range(len(FLOAT_TO_E2M1)): + conversion_dict[FLOAT_TO_E2M1[i]] = i + + +def fp4_to_index(value): + sign = torch.signbit(value) + x = torch.abs(value) + index = conversion_dict.get(x.item()) + + if not sign: # all positives + return index + else: # all negatives + return index + 8 + + +def pack_fp4_values(x: torch.Tensor): + x_flatten = x.flatten() + # convert to index value, unpack to bits + x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8) + x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to("cuda:0") + + packed_shape = ( + torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8).to("cuda:0") + ) + start = 0 + end = 16 + i = 0 + + # janky bit manipulation + while end <= len(x_index_bits): + print(start, end) + subset = x_index_bits[start:end] + + subset_a = subset[4:8] + subset_b = subset[12:16] + packed_shape[i + 4 : i + 8] = subset_a + packed_shape[i : i + 4] = subset_b + start = end + end = start + 16 + i += 8 + + # pack + packed = numpy.packbits(packed_shape.cpu().numpy()) + packed = torch.Tensor(packed).to(torch.uint8) + packed = packed.reshape(m, n // 2) + return packed + + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + +# reference: https://github.com/vllm-project/vllm/pull/16362 +def break_fp4_bytes(a, dtype=torch.float32): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +# fp4 tensor +x = torch.Tensor( + [ + [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000], + [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000], + [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000], + [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000], + ] +) + +m, n = x.shape + +packed = pack_fp4_values(x) +out = break_fp4_bytes(packed) +assert torch.equal(out, x) # misleading as -0 and 0 are considered equal +sign_bitx = torch.signbit(x) +sign_bitout = torch.signbit(out) +assert torch.equal(sign_bitout, sign_bitx) diff --git a/src/compressed_tensors/compressors/quantized_compressors/__init__.py b/src/compressed_tensors/compressors/quantized_compressors/__init__.py 
index 51e8b8e2..496519d4 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/__init__.py +++ b/src/compressed_tensors/compressors/quantized_compressors/__init__.py @@ -14,5 +14,6 @@ # flake8: noqa from .base import * +from .modelopt_quantized import * from .naive_quantized import * from .pack_quantized import * diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 098328be..16cdcd7c 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -113,6 +113,9 @@ def compress( scale = model_state.get(merge_names(prefix, "weight_scale"), None) zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) + global_scale = model_state.get( + merge_names(prefix, "weight_global_scale"), None + ) if scale is not None: # weight is quantized, compress it if isinstance(names_to_scheme[prefix], tuple): @@ -125,6 +128,7 @@ def compress( scale=scale, zero_point=zp, g_idx=g_idx, + global_scale=global_scale, quantization_args=quant_args, device="cpu", ) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py new file mode 100644 index 00000000..8068b046 --- /dev/null +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -0,0 +1,180 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, Optional, Tuple + +import numpy +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.quantized_compressors.base import ( + BaseQuantizationCompressor, +) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize +from torch import Tensor + + +FLOAT_TO_E2M1 = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, +] +conversion_dict = {} + +# Dictionary between fp4 value and index +for i in range(len(FLOAT_TO_E2M1)): + conversion_dict[FLOAT_TO_E2M1[i]] = i + + +def fp4_to_index(value): + sign = torch.signbit(value) + x = torch.abs(value) + index = conversion_dict.get(x.item()) + + if not sign: # all positives + return index + else: # all negatives + return index + 8 + + +@BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value) +class ModelOptCompressor(BaseQuantizationCompressor): + """ + Implements naive compression for quantized models. Weight of each + quantized layer is converted from its original float type to the closest Pytorch + type to the type specified by the layer's QuantizationArgs. 
+ """ + + @property + def compression_param_names(self) -> Tuple[str]: + """ + Returns a tuple of compression parameter names introduced by + the compressor during compression + """ + return ( + "weight_packed", + "weight_scale", + "weight_zero_point", + "weight_global_scale", + ) + + def compress_weight( + self, + weight: Tensor, + scale: Tensor, + global_scale: Tensor, + quantization_args: QuantizationArgs, + device: Optional[torch.device] = None, + zero_point: Optional[torch.Tensor] = None, + g_idx: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + + quantized_weight = quantize( + x=weight, + scale=scale, + global_scale=global_scale, + zero_point=zero_point, + args=quantization_args, + ) + compressed_dict = {} + weight_packed = pack_fp4_to_uint8(quantized_weight) + compressed_dict["weight_packed"] = weight_packed + return compressed_dict + + def decompress_weight( + self, + compressed_data: Dict[str, Tensor], + quantization_args: Optional[QuantizationArgs] = None, + ) -> torch.Tensor: + + weight = compressed_data["weight_packed"] + scale = compressed_data["weight_scale"] + global_scale = compressed_data["weight_global_scale"] + m, n = weight.shape + # TODO: need a way to pass in the output_dtype - can't be assumed based on the scales + # for nvfp4 (maybe the global scale can be not fp32?) + unpacked = unpack_fp4_from_uint8(weight, m, n * 2) + decompressed_weight = dequantize( + x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype + ) + + return decompressed_weight + + +def pack_fp4_to_uint8(x: torch.Tensor): + m, n = x.shape + x_flatten = x.flatten() + # convert to index value, unpack to bits + x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8) + x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to("cuda:0") + + packed_shape = ( + torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8).to("cuda:0") + ) + start = 0 + end = 16 + i = 0 + + # janky bit manipulation + while end <= len(x_index_bits): + subset = x_index_bits[start:end] + + subset_a = subset[4:8] + subset_b = subset[12:16] + packed_shape[i + 4 : i + 8] = subset_a + packed_shape[i : i + 4] = subset_b + start = end + end = start + 16 + i += 8 + + # pack + packed = numpy.packbits(packed_shape.cpu().numpy()) + packed = torch.Tensor(packed).to(torch.uint8).to("cuda:0") + packed = packed.reshape(m, n // 2) + return packed + + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + +# reference: : https://github.com/vllm-project/vllm/pull/16362 +def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.float16): + assert a.dtype == torch.uint8 + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n).to(dtype=dtype) diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 9ca6f2cf..3ec3bc46 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -32,6 +32,7 @@ class 
CompressionFormat(Enum): naive_quantized = "naive-quantized" pack_quantized = "pack-quantized" marlin_24 = "marlin-24" + modelopt_quantized = "modelopt-quantized" @unique From 7fbf3004714a780a5bfb8a675beea66c155b6ba9 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 5 May 2025 15:21:38 +0000 Subject: [PATCH 17/21] remove script, add tests --- .../compressors/compress_to_fp4.py | 126 ------------------ .../modelopt_quantized.py | 7 +- .../test_modelopt_quant.py | 21 +++ 3 files changed, 25 insertions(+), 129 deletions(-) delete mode 100644 src/compressed_tensors/compressors/compress_to_fp4.py create mode 100644 tests/test_compressors/quantized_compressors/test_modelopt_quant.py diff --git a/src/compressed_tensors/compressors/compress_to_fp4.py b/src/compressed_tensors/compressors/compress_to_fp4.py deleted file mode 100644 index a523f9c0..00000000 --- a/src/compressed_tensors/compressors/compress_to_fp4.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy -import torch - - -FLOAT_TO_E2M1 = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, -] -conversion_dict = {} - -# Dictionary between fp4 value and index -for i in range(len(FLOAT_TO_E2M1)): - conversion_dict[FLOAT_TO_E2M1[i]] = i - - -def fp4_to_index(value): - sign = torch.signbit(value) - x = torch.abs(value) - index = conversion_dict.get(x.item()) - - if not sign: # all positives - return index - else: # all negatives - return index + 8 - - -def pack_fp4_values(x: torch.Tensor): - x_flatten = x.flatten() - # convert to index value, unpack to bits - x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8) - x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to("cuda:0") - - packed_shape = ( - torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8).to("cuda:0") - ) - start = 0 - end = 16 - i = 0 - - # janky bit manipulation - while end <= len(x_index_bits): - print(start, end) - subset = x_index_bits[start:end] - - subset_a = subset[4:8] - subset_b = subset[12:16] - packed_shape[i + 4 : i + 8] = subset_a - packed_shape[i : i + 4] = subset_b - start = end - end = start + 16 - i += 8 - - # pack - packed = numpy.packbits(packed_shape.cpu().numpy()) - packed = torch.Tensor(packed).to(torch.uint8) - packed = packed.reshape(m, n // 2) - return packed - - -kE2M1ToFloat = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 -) - -# reference: https://github.com/vllm-project/vllm/pull/16362 -def break_fp4_bytes(a, dtype=torch.float32): - assert a.dtype == torch.uint8 - m, n = a.shape - - # Vectorized nibble processing - a_flat = a.flatten() - high = (a_flat & 0xF0) >> 4 # Upper nibbles - low = a_flat & 0x0F # Lower nibbles - - # Combine nibbles for batch processing - combined = torch.stack((low, high), dim=1).flatten() - - # Vectorized sign and magnitude extraction - signs = (combined & 0x08).to(torch.bool) # Sign bits - abs_vals = (combined & 0x07).to(torch.long) 
# Magnitude indices - - # Device-aware lookup and sign application - kE2M1 = kE2M1ToFloat.to(device=a.device) - values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) - - # Reshape to final form - return values.reshape(m, n * 2).to(dtype=dtype) - - -# fp4 tensor -x = torch.Tensor( - [ - [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000], - [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000], - [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000], - [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000], - ] -) - -m, n = x.shape - -packed = pack_fp4_values(x) -out = break_fp4_bytes(packed) -assert torch.equal(out, x) # misleading as -0 and 0 are considered equal -sign_bitx = torch.signbit(x) -sign_bitout = torch.signbit(out) -assert torch.equal(sign_bitout, sign_bitx) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py index 8068b046..e6bb454e 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -26,6 +26,7 @@ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize from torch import Tensor +__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"] FLOAT_TO_E2M1 = [ 0.0, @@ -124,10 +125,10 @@ def pack_fp4_to_uint8(x: torch.Tensor): x_flatten = x.flatten() # convert to index value, unpack to bits x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8) - x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to("cuda:0") + x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)) packed_shape = ( - torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8).to("cuda:0") + torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8) ) start = 0 end = 16 @@ -147,7 +148,7 @@ def pack_fp4_to_uint8(x: torch.Tensor): # pack packed = numpy.packbits(packed_shape.cpu().numpy()) - packed = torch.Tensor(packed).to(torch.uint8).to("cuda:0") + packed = torch.Tensor(packed).to(torch.uint8) packed = packed.reshape(m, n // 2) return packed diff --git a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py new file mode 100644 index 00000000..318d434d --- /dev/null +++ b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py @@ -0,0 +1,21 @@ +import torch +from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import pack_fp4_to_uint8, unpack_fp4_from_uint8 + +def test_pack_unpack(): + x = torch.Tensor( + [ + [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000], + [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000], + [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000], + [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000], + ] + ) + m, n = x.shape + packed = pack_fp4_to_uint8(x) + unpacked = unpack_fp4_from_uint8(packed, m, n) + + assert torch.equal(unpacked, x) # misleading as -0 and 0 are considered equal + sign_bitx = torch.signbit(x) + sign_bitout = torch.signbit(unpacked) + assert torch.equal(sign_bitout, sign_bitx) + From 5544ef43776397aa7f5508bc28c50cc326bcd6b6 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 6 May 2025 12:09:11 -0400 Subject: [PATCH 18/21] Optimize pack_fp4_to_uint8 for fp4 (#309) Signed-off-by: mgoin --- 
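For reviewers: the byte layout is unchanged by this rewrite. Each FP4 (E2M1) value maps to a 4-bit code, sign in bit 3 and magnitude index 0-7 (over 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0) in bits 0-2, and consecutive pairs of codes share one uint8 with the first value stored in the low nibble. A minimal round-trip sketch using the helpers added earlier in this series, illustrative only and not part of the diff:

    import torch
    from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import (
        pack_fp4_to_uint8,
        unpack_fp4_from_uint8,
    )

    # 0.5 -> code 0b0001, -6.0 -> code 0b1111; packed low-nibble-first -> 0xF1 == 241
    x = torch.tensor([[0.5, -6.0]], dtype=torch.float32)
    packed = pack_fp4_to_uint8(x)                              # tensor([[241]], dtype=torch.uint8)
    restored = unpack_fp4_from_uint8(packed, m=1, n=2, dtype=torch.float32)
    assert torch.equal(restored, x)
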
.../modelopt_quantized.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py index e6bb454e..9ffadf72 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -122,35 +122,35 @@ def decompress_weight( def pack_fp4_to_uint8(x: torch.Tensor): m, n = x.shape - x_flatten = x.flatten() - # convert to index value, unpack to bits - x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8) - x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)) - - packed_shape = ( - torch.zeros([x_index_bits.shape[0] // 2]).to(torch.uint8) - ) - start = 0 - end = 16 - i = 0 - - # janky bit manipulation - while end <= len(x_index_bits): - subset = x_index_bits[start:end] - - subset_a = subset[4:8] - subset_b = subset[12:16] - packed_shape[i + 4 : i + 8] = subset_a - packed_shape[i : i + 4] = subset_b - start = end - end = start + 16 - i += 8 - - # pack - packed = numpy.packbits(packed_shape.cpu().numpy()) - packed = torch.Tensor(packed).to(torch.uint8) - packed = packed.reshape(m, n // 2) - return packed + device = x.device + + # Create lookup table for FP4 values to indices + # Map the absolute values to 0-7 indices + kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device) + + # Find closest valid FP4 value index for each element + abs_x = torch.abs(x) + abs_indices = torch.zeros_like(abs_x, dtype=torch.long) + for i, val in enumerate(kE2M1): + abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices) + + # Apply sign bit (bit 3) to get final 4-bit representation + indices = abs_indices + (torch.signbit(x) * 8).to(torch.long) + + # Reshape to prepare for packing pairs of values + indices = indices.reshape(-1) + + # Handle odd length by padding if necessary + if indices.numel() % 2 != 0: + indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) + + # Reshape to pair consecutive elements + indices = indices.reshape(-1, 2) + + # Pack pairs of 4-bit values into 8-bit values + packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8) + + return packed.reshape(m, n // 2) kE2M1ToFloat = torch.tensor( From ddb41ab9d2703b3dee3e9c6fa0843b6d7fa5a85c Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 May 2025 17:02:26 +0000 Subject: [PATCH 19/21] fix pack dtype, update test, clean-up --- .../modelopt_quantized.py | 29 +++++-------------- .../test_modelopt_quant.py | 28 ++++++++++++++++-- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py index 9ffadf72..5e56a6b5 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -26,6 +26,7 @@ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize from torch import Tensor + __all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"] FLOAT_TO_E2M1 = [ @@ -38,22 +39,6 @@ 4.0, 6.0, ] -conversion_dict = {} - -# Dictionary between fp4 value and index -for i in range(len(FLOAT_TO_E2M1)): - conversion_dict[FLOAT_TO_E2M1[i]] = i - - -def fp4_to_index(value): - sign = torch.signbit(value) - x = 
torch.abs(value) - index = conversion_dict.get(x.item()) - - if not sign: # all positives - return index - else: # all negatives - return index + 8 @BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value) @@ -97,6 +82,8 @@ def compress_weight( ) compressed_dict = {} weight_packed = pack_fp4_to_uint8(quantized_weight) + if device is not None: + weight_packed = weight_packed.to(device) compressed_dict["weight_packed"] = weight_packed return compressed_dict @@ -110,9 +97,9 @@ def decompress_weight( scale = compressed_data["weight_scale"] global_scale = compressed_data["weight_global_scale"] m, n = weight.shape - # TODO: need a way to pass in the output_dtype - can't be assumed based on the scales - # for nvfp4 (maybe the global scale can be not fp32?) - unpacked = unpack_fp4_from_uint8(weight, m, n * 2) + # TODO: we may not always use the global_scale dtype as the detype to dequant + # We need to pass in the pretrained model dtype to the compressors + unpacked = unpack_fp4_from_uint8(weight, m, n * 2, dtype=global_scale.dtype) decompressed_weight = dequantize( x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype ) @@ -126,7 +113,7 @@ def pack_fp4_to_uint8(x: torch.Tensor): # Create lookup table for FP4 values to indices # Map the absolute values to 0-7 indices - kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device) + kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype) # Find closest valid FP4 value index for each element abs_x = torch.abs(x) @@ -158,7 +145,7 @@ def pack_fp4_to_uint8(x: torch.Tensor): ) # reference: : https://github.com/vllm-project/vllm/pull/16362 -def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.float16): +def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.bfloat16): assert a.dtype == torch.uint8 # Vectorized nibble processing diff --git a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py index 318d434d..75709033 100644 --- a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py +++ b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py @@ -1,5 +1,23 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import torch -from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import pack_fp4_to_uint8, unpack_fp4_from_uint8 +from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import ( + pack_fp4_to_uint8, + unpack_fp4_from_uint8, +) + def test_pack_unpack(): x = torch.Tensor( @@ -10,12 +28,16 @@ def test_pack_unpack(): [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000], ] ) + + dense_dtype = torch.bfloat16 + x = x.to(dense_dtype) m, n = x.shape packed = pack_fp4_to_uint8(x) - unpacked = unpack_fp4_from_uint8(packed, m, n) + assert packed.dtype == torch.uint8 + unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=torch.float16) + assert unpacked.dtype == torch.float16 assert torch.equal(unpacked, x) # misleading as -0 and 0 are considered equal sign_bitx = torch.signbit(x) sign_bitout = torch.signbit(unpacked) assert torch.equal(sign_bitout, sign_bitx) - From a95520d535f901305857b7c6da62e68cc4ee3f9a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 7 May 2025 16:42:50 +0000 Subject: [PATCH 20/21] update compressor --- .../compressors/quantized_compressors/modelopt_quantized.py | 2 +- .../quantized_compressors/test_modelopt_quant.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py index 5e56a6b5..aeb3ccca 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -99,7 +99,7 @@ def decompress_weight( m, n = weight.shape # TODO: we may not always use the global_scale dtype as the detype to dequant # We need to pass in the pretrained model dtype to the compressors - unpacked = unpack_fp4_from_uint8(weight, m, n * 2, dtype=global_scale.dtype) + unpacked = unpack_fp4_from_uint8(weight, m, n * 2) decompressed_weight = dequantize( x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype ) diff --git a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py index 75709033..b5f81e67 100644 --- a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py +++ b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py @@ -34,8 +34,8 @@ def test_pack_unpack(): m, n = x.shape packed = pack_fp4_to_uint8(x) assert packed.dtype == torch.uint8 - unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=torch.float16) - assert unpacked.dtype == torch.float16 + unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=dense_dtype) + assert unpacked.dtype == dense_dtype assert torch.equal(unpacked, x) # misleading as -0 and 0 are considered equal sign_bitx = torch.signbit(x) From 7435b3f2cae99c2014fc7318a6679fd32f345644 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 8 May 2025 20:17:12 +0000 Subject: [PATCH 21/21] update global scale calculation --- .../quantization/lifecycle/apply.py | 59 ++++++++++++++++++- .../quantization/utils/helpers.py | 8 ++- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 0e6c3d5f..0d4aaae0 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -28,7 +28,11 @@ from 
compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + FP4_E2M1_DATA, + FP8_E4M3_DATA, + QuantizationArgs, +) from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, @@ -238,6 +242,55 @@ def process_kv_cache_config( return config +def is_attention_module(module: Module): + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") + or hasattr(module, "v_proj") + or hasattr(module, "qkv_proj") + ) + + +def is_mlp_module(module: Module): + return "mlp" in module.__class__.__name__.lower() and ( + hasattr(module, "gate_proj") or hasattr(module, "up_porj") + ) + + +def update_fp4_global_scales(model): + for name, submodule in iter_named_quantizable_modules( + model, + include_attn=True, + include_mlp=True, + ): + if is_attention_module(submodule): + q_weight = submodule.q_proj.weight.data + v_weight = submodule.v_proj.weight.data + k_weight = submodule.k_proj.weight.data + all_data = torch.cat((q_weight, v_weight, k_weight), dim=0) + + scale_dtype = FP8_E4M3_DATA.dtype + tensor_amax = torch.abs(all_data.data).max().to(torch.float32) + value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax + value = value.to(torch.float32) + + update_parameter_data(submodule.q_proj, value, "weight_global_scale") + update_parameter_data(submodule.k_proj, value, "weight_global_scale") + update_parameter_data(submodule.v_proj, value, "weight_global_scale") + + if is_mlp_module(submodule): + gate_data = submodule.gate_proj.weight.data + up_data = submodule.up_proj.weight.data + all_data = torch.cat((gate_data, up_data), dim=0) + + scale_dtype = FP8_E4M3_DATA.dtype + tensor_amax = torch.abs(all_data.data).max().to(torch.float32) + value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax + value = value.to(torch.float32) + + update_parameter_data(submodule.gate_proj, value, "weight_global_scale") + update_parameter_data(submodule.up_proj, value, "weight_global_scale") + + def apply_quantization_status(model: Module, status: QuantizationStatus): """ Applies in place the quantization lifecycle up to the given status @@ -266,6 +319,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): ) ) + # hacks + if status == QuantizationStatus.INITIALIZED: + update_fp4_global_scales(model) + if current_status < status >= QuantizationStatus.COMPRESSED > current_status: model.apply(compress_quantized_weights) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index c9429231..fcafdcbc 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -274,7 +274,10 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None def iter_named_quantizable_modules( - model: Module, include_children: bool = True, include_attn: bool = False + model: Module, + include_children: bool = True, + include_attn: bool = False, + include_mlp: bool = False, ) -> Generator[Tuple[str, Module], None, None]: """ Yield name and submodule of @@ -307,6 +310,9 @@ def iter_named_quantizable_modules( if include_attn: if name.endswith("self_attn"): yield name, submodule + if include_mlp: + if name.endswith("mlp"): + yield name, submodule def get_torch_bit_depth(value: torch.Tensor) -> int:
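
A closing note on the global-scale formula added in this last commit: NVFP4 uses one FP32 global scale per fused weight group, computed as the FP8 E4M3 maximum (448) times the FP4 E2M1 maximum (6.0) divided by the group's absolute maximum, which keeps the per-group FP8 scales derived later within the E4M3 range. A standalone sketch of that calculation, assuming only torch; the helper name nvfp4_global_scale is illustrative and not part of the library API:

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3 value
    FP4_E2M1_MAX = 6.0    # largest representable E2M1 value

    def nvfp4_global_scale(weights):
        """One shared global scale for a fused weight group (e.g. q/k/v or gate/up),
        mirroring the computation in update_fp4_global_scales above."""
        amax = torch.cat([w.detach().flatten() for w in weights]).abs().max()
        return (FP8_E4M3_MAX * FP4_E2M1_MAX / amax.to(torch.float32)).to(torch.float32)

    # example: the q/k/v projections of one attention block share a single global scale
    q, k, v = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 16)
    global_scale = nvfp4_global_scale([q, k, v])  # == 448 * 6 / max(|q|, |k|, |v|)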