diff --git a/src/compressed_tensors/compressors/quantized_compressors/__init__.py b/src/compressed_tensors/compressors/quantized_compressors/__init__.py index 51e8b8e2..496519d4 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/__init__.py +++ b/src/compressed_tensors/compressors/quantized_compressors/__init__.py @@ -14,5 +14,6 @@ # flake8: noqa from .base import * +from .modelopt_quantized import * from .naive_quantized import * from .pack_quantized import * diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 098328be..16cdcd7c 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -113,6 +113,9 @@ def compress( scale = model_state.get(merge_names(prefix, "weight_scale"), None) zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None) + global_scale = model_state.get( + merge_names(prefix, "weight_global_scale"), None + ) if scale is not None: # weight is quantized, compress it if isinstance(names_to_scheme[prefix], tuple): @@ -125,6 +128,7 @@ def compress( scale=scale, zero_point=zp, g_idx=g_idx, + global_scale=global_scale, quantization_args=quant_args, device="cpu", ) diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py new file mode 100644 index 00000000..aeb3ccca --- /dev/null +++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, Optional, Tuple + +import numpy +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.quantized_compressors.base import ( + BaseQuantizationCompressor, +) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize +from torch import Tensor + + +__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"] + +FLOAT_TO_E2M1 = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, +] + + +@BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value) +class ModelOptCompressor(BaseQuantizationCompressor): + """ + Implements naive compression for quantized models. Weight of each + quantized layer is converted from its original float type to the closest Pytorch + type to the type specified by the layer's QuantizationArgs. 
+    """
+
+    @property
+    def compression_param_names(self) -> Tuple[str]:
+        """
+        Returns a tuple of compression parameter names introduced by
+        the compressor during compression
+        """
+        return (
+            "weight_packed",
+            "weight_scale",
+            "weight_zero_point",
+            "weight_global_scale",
+        )
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        global_scale: Tensor,
+        quantization_args: QuantizationArgs,
+        device: Optional[torch.device] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+
+        quantized_weight = quantize(
+            x=weight,
+            scale=scale,
+            global_scale=global_scale,
+            zero_point=zero_point,
+            args=quantization_args,
+        )
+        compressed_dict = {}
+        weight_packed = pack_fp4_to_uint8(quantized_weight)
+        if device is not None:
+            weight_packed = weight_packed.to(device)
+        compressed_dict["weight_packed"] = weight_packed
+        return compressed_dict
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        global_scale = compressed_data["weight_global_scale"]
+        m, n = weight.shape
+        # TODO: we may not always use the global_scale dtype as the dtype to dequant
+        # We need to pass in the pretrained model dtype to the compressors
+        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
+        )
+
+        return decompressed_weight
+
+
+def pack_fp4_to_uint8(x: torch.Tensor):
+    m, n = x.shape
+    device = x.device
+
+    # Create lookup table for FP4 values to indices
+    # Map the absolute values to 0-7 indices
+    kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)
+
+    # Find closest valid FP4 value index for each element
+    abs_x = torch.abs(x)
+    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
+    for i, val in enumerate(kE2M1):
+        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
+
+    # Apply sign bit (bit 3) to get final 4-bit representation
+    indices = abs_indices + (torch.signbit(x) * 8).to(torch.long)
+
+    # Reshape to prepare for packing pairs of values
+    indices = indices.reshape(-1)
+
+    # Handle odd length by padding if necessary
+    if indices.numel() % 2 != 0:
+        indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
+
+    # Reshape to pair consecutive elements
+    indices = indices.reshape(-1, 2)
+
+    # Pack pairs of 4-bit values into 8-bit values
+    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
+
+    return packed.reshape(m, n // 2)
+
+
+kE2M1ToFloat = torch.tensor(
+    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
+)
+
+# reference: https://github.com/vllm-project/vllm/pull/16362
+def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.bfloat16):
+    assert a.dtype == torch.uint8
+
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices
+
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+
+    # Reshape to final form
+    return values.reshape(m, n).to(dtype=dtype)
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 9ca6f2cf..3ec3bc46 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    modelopt_quantized = "modelopt-quantized"
 
 
 @unique
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 0e6c3d5f..0d4aaae0 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -28,7 +28,11 @@ from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
+    QuantizationArgs,
+)
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -238,6 +242,55 @@ def process_kv_cache_config(
     return config
+
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def is_mlp_module(module: Module):
+    return "mlp" in module.__class__.__name__.lower() and (
+        hasattr(module, "gate_proj") or hasattr(module, "up_proj")
+    )
+
+
+def update_fp4_global_scales(model):
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_attn=True,
+        include_mlp=True,
+    ):
+        if is_attention_module(submodule):
+            q_weight = submodule.q_proj.weight.data
+            v_weight = submodule.v_proj.weight.data
+            k_weight = submodule.k_proj.weight.data
+            all_data = torch.cat((q_weight, v_weight, k_weight), dim=0)
+
+            scale_dtype = FP8_E4M3_DATA.dtype
+            tensor_amax = torch.abs(all_data.data).max().to(torch.float32)
+            value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+            value = value.to(torch.float32)
+
+            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
+            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
+            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
+
+        if is_mlp_module(submodule):
+            gate_data = submodule.gate_proj.weight.data
+            up_data = submodule.up_proj.weight.data
+            all_data = torch.cat((gate_data, up_data), dim=0)
+
+            scale_dtype = FP8_E4M3_DATA.dtype
+            tensor_amax = torch.abs(all_data.data).max().to(torch.float32)
+            value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+            value = value.to(torch.float32)
+
+            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
+            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
+
+
 def apply_quantization_status(model: Module, status: QuantizationStatus):
     """
     Applies in place the quantization lifecycle up to the given status
@@ -266,6 +319,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
+    # temporary workaround: set NVFP4 global scales after initialization
+    if status == QuantizationStatus.INITIALIZED:
+        update_fp4_global_scales(model)
+
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
 
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index f4f93f27..9fc3a68d 100644
---
a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -20,6 +20,7 @@ from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, + QuantizationType, round_to_quantized_type, ) from compressed_tensors.quantization.quant_config import QuantizationStatus @@ -49,6 +50,7 @@ def quantize( args: QuantizationArgs, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Quantize the input tensor x using the QuantizationStrategy specified in args. @@ -75,6 +77,7 @@ def quantize( do_quantize=True, do_dequantize=False, g_idx=g_idx, + global_scale=global_scale, ) @@ -86,6 +89,7 @@ def dequantize( args: Optional[QuantizationArgs] = None, dtype: Optional[torch.dtype] = None, g_idx: Optional[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Dequantize a quantized input tensor x_q based on the strategy specified in args. If @@ -128,6 +132,7 @@ def dequantize( do_dequantize=True, dtype=dtype, g_idx=g_idx, + global_scale=global_scale, ) @@ -138,6 +143,7 @@ def fake_quantize( zero_point: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Fake quantize the input tensor x by quantizing then dequantizing with @@ -161,6 +167,7 @@ def fake_quantize( do_quantize=True, do_dequantize=True, g_idx=g_idx, + global_scale=global_scale, ) @@ -174,6 +181,7 @@ def _process_quantization( dtype: Optional[torch.dtype] = None, do_quantize: bool = True, do_dequantize: bool = True, + global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: q_min, q_max = calculate_range(args, x.device) group_size = args.group_size @@ -221,18 +229,21 @@ def _process_quantization( end = start + group_count if do_quantize: output[:, start:end] = _quantize( - x[:, start:end], - sc, - zp, - q_min, - q_max, - args, + x=x[:, start:end], + scale=sc, + zero_point=zp, + q_min=q_min, + q_max=q_max, + args=args, dtype=dtype, + global_scale=global_scale, ) if do_dequantize: input = output[:, start:end] if do_quantize else x[:, start:end] - output[:, start:end] = _dequantize(input, sc, zp) + output[:, start:end] = _dequantize( + x_q=input, scale=sc, zero_point=zp, global_scale=global_scale + ) if not is_column_order: output = safe_permute(output, torch.argsort(perm), dim=1) @@ -240,16 +251,22 @@ def _process_quantization( else: # covers channel, token and tensor strategies if do_quantize: output = _quantize( - x, - scale, - zero_point, - q_min, - q_max, - args, + x=x, + scale=scale, + zero_point=zero_point, + q_min=q_min, + q_max=q_max, + args=args, dtype=dtype, + global_scale=global_scale, ) if do_dequantize: - output = _dequantize(output if do_quantize else x, scale, zero_point) + output = _dequantize( + output if do_quantize else x, + scale=scale, + zero_point=zero_point, + global_scale=global_scale, + ) return output @@ -330,6 +347,7 @@ def forward_quantize( return value g_idx = getattr(module, "weight_g_idx", None) + global_scale = getattr(module, f"{base_name}_global_scale", None) if args.dynamic: # dynamic quantization - determine the scale/zp on the fly @@ -345,6 +363,7 @@ def forward_quantize( zero_point=zero_point, args=args, g_idx=g_idx, + global_scale=global_scale, ) @@ -357,11 +376,16 @@ def _quantize( q_max: torch.Tensor, args: QuantizationArgs, dtype: Optional[torch.dtype] = None, + global_scale: 
Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if global_scale is not None:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     scaled = x / scale
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
+
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -381,7 +405,12 @@ def _dequantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+
+    if global_scale is not None:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 2b2ecf98..9f305b34 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -22,9 +22,12 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
+    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -169,10 +172,39 @@ def _initialize_scale_zero_point(
         expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+
+    # NVFP4 support; use FP8 scales
+    # For weight quant, attach global scales for NVFP4
+    if (
+        quantization_args.num_bits == 4
+        and quantization_args.type == QuantizationType.FLOAT
+    ):
+        if base_name == "weight":
+            scale_dtype = FP8_E4M3_DATA.dtype
+            # create and attach nvfp4 data
+            tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
+            # Setting data for now - could possibly be handled later in the pipeline
+            value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+            # TODO: use model.weight.dtype after checking
+            value = value.to(torch.float32).to(device)
+            # TODO: could the global scale follow the module weight dtype instead of float32?
+ init_global_scale = Parameter(value, requires_grad=False) + register_offload_parameter( + module, f"{base_name}_global_scale", init_global_scale + ) + else: + # input scales should be float32 + scale_dtype = torch.float32 + # TODO: consider erroring out in the future as if the dtype if not one fo these, # there is likely bug - if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]: + if scale_dtype not in [ + torch.float16, + torch.bfloat16, + torch.float32, + FP8_E4M3_DATA.dtype, + ]: scale_dtype = torch.float16 # initializes empty scale, zero point, and g_idx parameters for the module @@ -183,7 +215,14 @@ def _initialize_scale_zero_point( register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: - zp_dtype = quantization_args.pytorch_dtype() + if ( + quantization_args.num_bits == 4 + and quantization_args.type == QuantizationType.FLOAT + ): + zp_dtype = FP8_E4M3_DATA.dtype + else: + zp_dtype = quantization_args.pytorch_dtype() + init_zero_point = Parameter( torch.zeros(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 69c289d2..12d7f72f 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, Union @@ -24,6 +25,8 @@ __all__ = [ "FP8_DTYPE", + "FP8_E4M3_DATA", + "FP4_E2M1_DATA", "QuantizationType", "QuantizationStrategy", "QuantizationArgs", @@ -31,8 +34,47 @@ "ActivationOrdering", ] + +@dataclass +class FloatArgs: + exponent: int + mantissa: int + bits: int + max: float + min: float + dtype: Optional[torch.dtype] = None + + +@dataclass +class FloatArgsFP4E2M1(FloatArgs): + def cast_to_fp4(self, x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +# TODO: Remove soon in favour of a more descriptive FloatArgs FP8_DTYPE = torch.float8_e4m3fn +FP8_E4M3_DATA = FloatArgs( + exponent=4, + mantissa=3, + bits=8, + max=torch.finfo(torch.float8_e4m3fn).max, + min=torch.finfo(torch.float8_e4m3fn).min, + dtype=torch.float8_e4m3fn, +) + +FP4_E2M1_DATA = FloatArgsFP4E2M1(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0) + class QuantizationType(str, Enum): """ @@ -233,8 +275,14 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: return model def pytorch_dtype(self) -> torch.dtype: + # TODO: required for the compressor + # Add FP4_nvfp4 type when updating naive_compressor if self.type == QuantizationType.FLOAT: - return FP8_DTYPE + if self.num_bits == 8: + return FP8_E4M3_DATA.dtype + else: + assert self.num_bits == 4 + raise NotImplementedError("Not supported for FP4") elif self.type == QuantizationType.INT: if self.num_bits <= 8: return torch.int8 @@ -263,7 +311,11 @@ def round_to_quantized_type( """ original_dtype = tensor.dtype if args.type == QuantizationType.FLOAT: - rounded = tensor.to(FP8_DTYPE) + if args.num_bits == 8: + rounded = tensor.to(FP8_E4M3_DATA.dtype) + else: + assert args.num_bits == 4 + rounded = FP4_E2M1_DATA.cast_to_fp4(tensor) elif 
args.type == QuantizationType.INT:
         rounded = torch.round(tensor)
     else:
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 9fcc0d55..3d5c9a76 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool:
 
 UNQUANTIZED = dict()
 
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    )
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -225,4 +236,5 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "NVFP4": NVFP4,
 }
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index d7e6d5f8..fcafdcbc 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -17,7 +17,8 @@ import torch
 from compressed_tensors.quantization.quant_args import (
-    FP8_DTYPE,
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
@@ -54,7 +55,10 @@ def calculate_qparams(
-    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+    min_vals: Tensor,
+    max_vals: Tensor,
+    quantization_args: QuantizationArgs,
+    global_scale: Optional[Tensor] = None,
 ) -> Tuple[FloatTensor, IntTensor]:
     """
     :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
@@ -73,12 +77,33 @@ def calculate_qparams(
     bit_min, bit_max = calculate_range(quantization_args, device)
     bit_range = bit_max - bit_min
-    zp_dtype = quantization_args.pytorch_dtype()
+    # TODO: restore quantization_args.pytorch_dtype() once it supports FP4
+    # zp_dtype = quantization_args.pytorch_dtype()
+    zp_dtype = FP8_E4M3_DATA.dtype
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+            and global_scale is not None
+        ):
+            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)  # TODO: is the global_scale factor needed?
+            scales = scales.to(FP8_E4M3_DATA.dtype)
+        else:
+            # standard symmetric scale: max value over half the bit range
+            scales = max_val_pos / (float(bit_range) / 2)
+
+        if scales.dtype == FP8_E4M3_DATA.dtype:
+            # avoid zero scales: bump exact zeros up to a small representable fp8 value
+            # TODO: optionally store the reciprocal of the scale instead
+            scales = torch.where(
+                scales == 0, torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype), scales
+            )
+        else:
+            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
@@ -144,14 +169,14 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
         q_max = torch.tensor(bit_range / 2 - 1, device=device)
         q_min = torch.tensor(-bit_range / 2, device=device)
     elif quantization_args.type == QuantizationType.FLOAT:
-        if quantization_args.num_bits != 8:
-            raise ValueError(
-                "Floating point quantization is only supported for 8 bits,"
-                f"got {quantization_args.num_bits}"
-            )
-        fp_range_info = torch.finfo(FP8_DTYPE)
-        q_max = torch.tensor(fp_range_info.max, device=device)
-        q_min = torch.tensor(fp_range_info.min, device=device)
+        if quantization_args.num_bits == 8:
+            q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
+            q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
+        else:
+            # nvfp4 ranges
+            assert quantization_args.num_bits == 4
+            q_max = torch.tensor(FP4_E2M1_DATA.max, device=device)
+            q_min = torch.tensor(FP4_E2M1_DATA.min, device=device)
     else:
         raise ValueError(f"Invalid quantization type {quantization_args.type}")
 
@@ -249,7 +274,10 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
 
 
 def iter_named_quantizable_modules(
-    model: Module, include_children: bool = True, include_attn: bool = False
+    model: Module,
+    include_children: bool = True,
+    include_attn: bool = False,
+    include_mlp: bool = False,
 ) -> Generator[Tuple[str, Module], None, None]:
     """
     Yield name and submodule of
@@ -282,6 +310,9 @@ def iter_named_quantizable_modules(
         if include_attn:
             if name.endswith("self_attn"):
                 yield name, submodule
+        if include_mlp:
+            if name.endswith("mlp"):
+                yield name, submodule
 
 
 def get_torch_bit_depth(value: torch.Tensor) -> int:
diff --git a/tests/test_compressors/quantized_compressors/test_modelopt_quant.py b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py
new file mode 100644
index 00000000..b5f81e67
--- /dev/null
+++ b/tests/test_compressors/quantized_compressors/test_modelopt_quant.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import (
+    pack_fp4_to_uint8,
+    unpack_fp4_from_uint8,
+)
+
+
+def test_pack_unpack():
+    x = torch.Tensor(
+        [
+            [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000],
+            [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000],
+            [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
+            [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000],
+        ]
+    )
+
+    dense_dtype = torch.bfloat16
+    x = x.to(dense_dtype)
+    m, n = x.shape
+    packed = pack_fp4_to_uint8(x)
+    assert packed.dtype == torch.uint8
+    unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=dense_dtype)
+    assert unpacked.dtype == dense_dtype
+
+    assert torch.equal(unpacked, x)  # -0.0 == 0.0 here; sign bits checked below
+    sign_bitx = torch.signbit(x)
+    sign_bitout = torch.signbit(unpacked)
+    assert torch.equal(sign_bitout, sign_bitx)
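
Usage sketch (illustrative only, not part of the diff): the two-level NVFP4 scaling above combines a per-tensor global scale (update_fp4_global_scales / _initialize_scale_zero_point), a per-group scale stored in fp8 (calculate_qparams), and an effective decode scale recovered in _quantize/_dequantize. The helper names below are hypothetical, rounding to the E2M1 grid is omitted, and PyTorch with float8_e4m3fn support is assumed.

import torch

FP4_MAX = 6.0                                    # largest E2M1 magnitude (FP4_E2M1_DATA.max)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0 (FP8_E4M3_DATA.max)


def nvfp4_global_scale(weight: torch.Tensor) -> torch.Tensor:
    # per-tensor scale; the diff computes this over the full (or fused q/k/v) weight
    return (FP8_MAX * FP4_MAX / weight.abs().max()).to(torch.float32)


def nvfp4_group_scale(group: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
    # per-group scale, brought into fp8 range by the global scale before the cast
    return (global_scale * (group.abs().max() / FP4_MAX)).to(torch.float8_e4m3fn)


def nvfp4_fake_quant(group, scale_fp8, global_scale):
    # effective decode scale, mirroring _quantize/_dequantize
    effective = scale_fp8.to(global_scale.dtype) / global_scale
    q = torch.clamp(group / effective, -FP4_MAX, FP4_MAX)  # real path also rounds to E2M1
    return q * effective


group = torch.randn(16) * 0.05  # one group_size=16 slice of a weight row
gs = nvfp4_global_scale(group)
print(torch.mean((nvfp4_fake_quant(group, nvfp4_group_scale(group, gs), gs) - group) ** 2))

Storing group scales in fp8 is what makes the global scale necessary: without it, small per-group scales (group amax / 6.0) would lose most of their precision in the fp8_e4m3 cast, and the global scale shifts them into fp8's usable range.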