From 1a7daf83cc12f7bccd327e1e9a3a26d75fdd6d32 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Sun, 22 Jun 2025 20:29:27 -0400
Subject: [PATCH] wip

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/initialize.py | 90 ++++++++++++--------
 .../quantization/quant_args.py           |  6 ++
 2 files changed, 60 insertions(+), 36 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 806a98f0..6933b257 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -153,17 +153,6 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        init_global_scale = Parameter(
-            torch.empty(1, dtype=torch.float32, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
-
     # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
@@ -182,42 +171,71 @@ def _initialize_scale_zero_point(
         expected_shape = (weight_shape[0], max(num_groups, 1))
 
     # 3. Identify quantization scale and zp dtype
-    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+    unquantized_dtype = module.weight.dtype
+    quantized_dtype = quantization_args.pytorch_dtype()
 
-    if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        # TODO: consider erroring out in the future as if the dtype if not one of these,
-        # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
-            scale_dtype = torch.float16
-        zp_dtype = quantization_args.pytorch_dtype()
+    scale_dtype = scale_dtype if scale_dtype is not None else unquantized_dtype
+    # default scale = amax / quantized_max, defaulting amax to the weight dtype max
+    quantized_info = (
+        torch.finfo(quantized_dtype)
+        if quantized_dtype.is_floating_point
+        else torch.iinfo(quantized_dtype)
+    )
+    scale_value = torch.finfo(unquantized_dtype).max / quantized_info.max
+    zp_dtype = quantized_dtype
+
+    # fp4 is a special case where the qparams are stored in FP8 dtype
+    # note that fp4 zero points are not supported and raise an error
+    # during quantization args validation
+    if is_fp4(quantization_args):
+        scale_dtype = FP8_E4M3_DATA.dtype
+
+    # for tensor group quantization (fp4), initialize a global scale
+    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        register_offload_parameter(
+            module,
+            f"{base_name}_global_scale",
+            Parameter(
+                torch.empty(1, dtype=torch.float32, device=device),
+                requires_grad=False,
+            )
+        )
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
+    # 4. Initialize scale, zero point, and g_idx parameters for the module
+    # do not init scales for quantization_args.dynamic == DynamicType.local
     if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
+        register_offload_parameter(
+            module,
+            f"{base_name}_scale",
+            Parameter(
+                torch.full(
+                    expected_shape, scale_value, dtype=scale_dtype, device=device
+                ),
+                requires_grad=False,
+            )
         )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
+    # zero points
     if force_zero_point or not quantization_args.symmetric:
-        init_zero_point = Parameter(
-            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
-            requires_grad=False,
+        register_offload_parameter(
+            module,
+            f"{base_name}_zero_point",
+            Parameter(
+                torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+                requires_grad=False,
+            )
         )
-        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
+        register_offload_parameter(
+            module,
+            f"{base_name}_g_idx",
+            Parameter(
+                torch.arange(weight_shape[1], device=device, dtype=torch.int32),
+                requires_grad=False,
+            )
         )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
 def _initialize_attn_scales(module: Module) -> None:
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index fdf34a28..d0981faf 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -19,6 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
+from compressed_tensors.quantization.utils import is_fp4
 from pydantic import BaseModel, Field, field_validator, model_validator
 
 
@@ -310,6 +311,11 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
             # default to minmax for non-dynamic cases
             observer = "minmax"
 
+        # validate fp4
+        if is_fp4(model) and not model.symmetric:
+            raise NotImplementedError("FP4 asymmetric quantization is not supported")
+
+        # write back modified values
         model.strategy = strategy
        model.observer = observer
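
Reviewer note (not part of the patch): the scale initialization above replaces
torch.empty with a deterministic default value. Below is a minimal standalone
sketch of the intended computation, assuming the usual affine convention
x_q = round(x / scale) so that scale = amax / quantized_max; the helper name
default_scale_value is illustrative and not part of the library:

    import torch

    def default_scale_value(
        unquantized_dtype: torch.dtype, quantized_dtype: torch.dtype
    ) -> float:
        """Default scale assuming values span the full unquantized range."""
        # largest representable value of the quantized dtype (float8, int8, ...)
        quantized_max = (
            torch.finfo(quantized_dtype).max
            if quantized_dtype.is_floating_point
            else torch.iinfo(quantized_dtype).max
        )
        # scale = amax / quantized_max, with amax defaulted to the dtype max
        return torch.finfo(unquantized_dtype).max / quantized_max

    # e.g. bfloat16 weights quantized to int8
    print(default_scale_value(torch.bfloat16, torch.int8))  # ~2.67e+36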
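
A second sketch of how the new fp4 validation is expected to surface to
callers, assuming QuantizationArgs is importable from
compressed_tensors.quantization and that the NotImplementedError raised in
the model validator propagates out of pydantic unchanged:

    from compressed_tensors.quantization import QuantizationArgs

    # symmetric fp4 passes validation
    QuantizationArgs(num_bits=4, type="float", symmetric=True)

    # asymmetric fp4 is rejected during model validation
    try:
        QuantizationArgs(num_bits=4, type="float", symmetric=False)
    except NotImplementedError as err:
        print(err)  # FP4 asymmetric quantization is not supported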