From 66368fbb8ba03a107f745035849318304970de45 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 09:28:57 -0400 Subject: [PATCH 1/5] wip: initial implementation Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 45 +------ .../quantization/lifecycle/forward.py | 48 ++++++++ .../quantization/lifecycle/initialize.py | 115 ++++++++---------- .../transform/transform_args.py | 9 +- 4 files changed, 105 insertions(+), 112 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c8dbeced..7a5e5334 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -140,7 +140,6 @@ def apply_quantization_config( # build mapping of targets to schemes for easier matching # use ordered dict to preserve target ordering in config target_to_scheme = OrderedDict() - config = process_quantization_config(config) names_to_scheme = dict() for scheme in config.config_groups.values(): for target in scheme.targets: @@ -152,13 +151,7 @@ def apply_quantization_config( # list of submodules to ignore ignored_submodules = defaultdict(list) # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in iter_named_quantizable_modules( - model, - include_children=True, - include_attn=True, - ): # child modules and attention modules - # potentially fix module name to remove FSDP wrapper prefix - name = fix_fsdp_module_name(name) + for name, submodule in model.named_modules(): if matches := find_name_or_class_matches(name, submodule, config.ignore): for match in matches: ignored_submodules[match].append(name) @@ -200,42 +193,6 @@ def apply_quantization_config( return names_to_scheme -def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: - """ - Preprocess the raw QuantizationConfig - - :param config: the raw QuantizationConfig - :return: the processed QuantizationConfig - """ - if config.kv_cache_scheme is not None: - config = process_kv_cache_config(config) - - return config - - -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: - """ - Reformulate the `config.kv_cache` as a `config_group` - and add it to the set of existing `config.groups` - - :param config: the QuantizationConfig - :return: the QuantizationConfig with additional "kv_cache" group - """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") - - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, - ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) - return config - - def apply_quantization_status(model: Module, status: QuantizationStatus): """ Applies in place the quantization lifecycle up to the given status diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b4ca3a82..a70616fc 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -31,6 +31,7 @@ ) from compressed_tensors.utils import safe_permute from torch.nn import Module +from transformers import AttentionInterface __all__ = [ @@ -42,6 +43,53 @@ ] +def calibrated_attention( + module: 
torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + from compressed_tensors.transform import TransformBase, TransformLocation + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + # apply transforms first + for submodule in module.children(): + if isinstance(submodule, TransformBase): + if TransformBase.args.location == TransformLocation.ATTN_Q: + query = submodule(query) + + if TransformBase.args.location == TransformLocation.ATTN_K: + key = submodule(key) + + # if TransformBase.args.location == TransformLocation.ATTN_V: + # key = submodule(key) + + # apply activation quantization second + scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) + if scheme is not None: + if scheme.input_activations is not None: + query = forward_quantize(module, query, "q", scheme.input_activations) + key = forward_quantize(module, key, "k", scheme.input_activations) + value = forward_quantize(module, value, "v", scheme.input_activations) + + if scheme.weights is not None: + raise ValueError("") + + if scheme.output_activations is not None: + raise NotImplementedError("") + + return ALL_ATTENTION_FUNCTIONS["eager"]( + module, query, key, value, attention_mask, scaling, dropout, **kwargs + ) + + +AttentionInterface.register("calibrated_attention", calibrated_attention) + + @torch.no_grad() def quantize( x: torch.Tensor, diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 806a98f0..2b8f888b 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -36,24 +36,17 @@ get_execution_device, register_offload_parameter, ) +from compressed_tensors.utils.helpers import patch_attr from torch.nn import Module, Parameter +from transformers.configuration_utils import PretrainedConfig __all__ = [ "initialize_module_for_quantization", "is_attention_module", - "KVCacheScaleType", ] -_LOGGER = logging.getLogger(__name__) - - -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, @@ -78,17 +71,40 @@ def initialize_module_for_quantization( # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return + # initialize scheme and status + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED + if is_attention_module(module): - # quantized actions based on calltime status - _initialize_attn_scales(module) + assert scheme.input_activations is not None + for base_name in ("q", "k", "v"): + _initialize_quantization_parameters( + module, + base_name, + scheme.input_activations, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) - else: + # wrap attention interface + config = getattr(module, "config") + original_forward = module.forward + assert isinstance(config, PretrainedConfig) and hasattr( + config, "_attn_implementation" + ) + + def wrapped_forward(self, *args, **kwargs): + with patch_attr(config, "_attn_implementation", "calibrated_attention"): + return original_forward(*args, **kwargs) + module.forward = wrapped_forward + return + + if 
isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): if scheme.input_activations is not None: - _initialize_scale_zero_point( + _initialize_quantization_parameters( module, "input", scheme.input_activations, @@ -97,41 +113,34 @@ def initialize_module_for_quantization( ) if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - scale_dtype=scale_dtype, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) + _initialize_quantization_parameters( + module, + "weight", + scheme.weights, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations, scale_dtype=scale_dtype - ) - - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED + _initialize_quantization_parameters( + module, + "output", + scheme.output_activations, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status wrap_module_forward_quantized(module, scheme) + else: + raise ValueError(f"Unsupported quantization target {type(module)}") + def is_attention_module(module: Module): + # can redefine to inspect source code for references to ALL_ATTENTION_FUNCTIONS return "attention" in module.__class__.__name__.lower() and ( hasattr(module, "k_proj") or hasattr(module, "v_proj") @@ -139,11 +148,10 @@ def is_attention_module(module: Module): ) -def _initialize_scale_zero_point( +def _initialize_quantization_parameters( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, force_zero_point: bool = True, scale_dtype: Optional[torch.dtype] = None, ): @@ -170,7 +178,8 @@ def _initialize_scale_zero_point( else: expected_shape = 1 - if base_name == "weight" and weight_shape is not None: + if base_name == "weight": + weight_shape = getattr(module, "weight").shape if quantization_args.strategy == QuantizationStrategy.CHANNEL: # (output_channels, 1) expected_shape = (weight_shape[0], 1) @@ -218,25 +227,3 @@ def _initialize_scale_zero_point( requires_grad=False, ) register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - - -def _initialize_attn_scales(module: Module) -> None: - """Initlaize k_scale, v_scale for self_attn""" - - expected_shape = 1 # per tensor - - param = next(module.parameters()) - scale_dtype = param.dtype - device = param.device - - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale) - - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index 16ab10b3..8b114812 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ 
b/src/compressed_tensors/transform/transform_args.py @@ -33,8 +33,8 @@ class TransformLocation(str, Enum): | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 - | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 - | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 + | `ATTN_Q` | online | query_states | `this.ATTN_K` | # noqa: E501 + | `ATTN_K` | online | key_states | `this.Q_ATTN` | # noqa: E501 | -------------------------------------------------------------------------------------------------------- | # noqa: E501 """ @@ -42,8 +42,9 @@ class TransformLocation(str, Enum): WEIGHT_INPUT = "weight_input" WEIGHT_OUTPUT = "weight_output" OUTPUT = "output" - K_CACHE = "k_cache" - Q_ATTN = "q_attn" + ATTN_Q = "attn_q" + ATTN_K = "attn_k" + # ATTN_V = "attn_v" class TransformArgs(BaseModel): From f5fb6464ea42b3db3406a0edb0496cebea98d02e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 10:35:11 -0400 Subject: [PATCH 2/5] move calibrated_attention to lc Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 47 ------------------- .../quantization/lifecycle/initialize.py | 36 +++++--------- 2 files changed, 11 insertions(+), 72 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index a70616fc..ea3aba9f 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -43,53 +43,6 @@ ] -def calibrated_attention( - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - from compressed_tensors.transform import TransformBase, TransformLocation - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - - # apply transforms first - for submodule in module.children(): - if isinstance(submodule, TransformBase): - if TransformBase.args.location == TransformLocation.ATTN_Q: - query = submodule(query) - - if TransformBase.args.location == TransformLocation.ATTN_K: - key = submodule(key) - - # if TransformBase.args.location == TransformLocation.ATTN_V: - # key = submodule(key) - - # apply activation quantization second - scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) - if scheme is not None: - if scheme.input_activations is not None: - query = forward_quantize(module, query, "q", scheme.input_activations) - key = forward_quantize(module, key, "k", scheme.input_activations) - value = forward_quantize(module, value, "v", scheme.input_activations) - - if scheme.weights is not None: - raise ValueError("") - - if scheme.output_activations is not None: - raise NotImplementedError("") - - return ALL_ATTENTION_FUNCTIONS["eager"]( - module, query, key, value, attention_mask, scaling, dropout, **kwargs - ) - - -AttentionInterface.register("calibrated_attention", calibrated_attention) - - @torch.no_grad() def quantize( x: torch.Tensor, diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2b8f888b..bafef14d 100644 --- 
a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -77,31 +77,6 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED - if is_attention_module(module): - assert scheme.input_activations is not None - for base_name in ("q", "k", "v"): - _initialize_quantization_parameters( - module, - base_name, - scheme.input_activations, - force_zero_point=force_zero_point, - scale_dtype=scale_dtype, - ) - - # wrap attention interface - config = getattr(module, "config") - original_forward = module.forward - assert isinstance(config, PretrainedConfig) and hasattr( - config, "_attn_implementation" - ) - - def wrapped_forward(self, *args, **kwargs): - with patch_attr(config, "_attn_implementation", "calibrated_attention"): - return original_forward(*args, **kwargs) - - module.forward = wrapped_forward - return - if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): if scheme.input_activations is not None: _initialize_quantization_parameters( @@ -135,6 +110,17 @@ def wrapped_forward(self, *args, **kwargs): # quantized actions based on calltime status wrap_module_forward_quantized(module, scheme) + elif is_attention_module(module): + assert scheme.input_activations is not None + for base_name in ("q", "k", "v"): + _initialize_quantization_parameters( + module, + base_name, + scheme.input_activations, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) + else: raise ValueError(f"Unsupported quantization target {type(module)}") From 8b61558e1ad8f18ec55392bac4c909711895a25b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 14:20:33 -0400 Subject: [PATCH 3/5] better zero point resolution Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index bafef14d..a3a6c918 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -177,7 +177,13 @@ def _initialize_quantization_parameters( expected_shape = (weight_shape[0], max(num_groups, 1)) # 3. Identify quantization scale and zp dtype - scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype + if scale_dtype is None: + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + scale_dtype = module.weight.dtype + elif is_attention_module(module): + scale_dtype = next(module.parameters()).dtype + else: + raise ValueError() if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype From 2daa566df50cfe6496c6febd252f2e3288b44904 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 15:04:01 -0400 Subject: [PATCH 4/5] add get_calibrated_locations Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index a3a6c918..b156cf23 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -13,10 +13,8 @@ # limitations under the License. 
-import logging import math -from enum import Enum -from typing import List, Optional +from typing import Optional, Tuple import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -30,20 +28,19 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme +from compressed_tensors.quantization.utils import is_fp4 from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, register_offload_parameter, ) -from compressed_tensors.utils.helpers import patch_attr from torch.nn import Module, Parameter -from transformers.configuration_utils import PretrainedConfig __all__ = [ "initialize_module_for_quantization", "is_attention_module", + "get_calibrated_locations", ] @@ -77,8 +74,10 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED + input, weight, output = get_calibrated_locations(scheme) + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): - if scheme.input_activations is not None: + if input: _initialize_quantization_parameters( module, "input", @@ -87,7 +86,7 @@ def initialize_module_for_quantization( scale_dtype=scale_dtype, ) - if scheme.weights is not None: + if weight: _initialize_quantization_parameters( module, "weight", @@ -96,7 +95,7 @@ def initialize_module_for_quantization( scale_dtype=scale_dtype, ) - if scheme.output_activations is not None: + if output: _initialize_quantization_parameters( module, "output", @@ -111,7 +110,7 @@ def initialize_module_for_quantization( wrap_module_forward_quantized(module, scheme) elif is_attention_module(module): - assert scheme.input_activations is not None + assert input and scheme.input_activations is not None for base_name in ("q", "k", "v"): _initialize_quantization_parameters( module, @@ -141,9 +140,6 @@ def _initialize_quantization_parameters( force_zero_point: bool = True, scale_dtype: Optional[torch.dtype] = None, ): - if quantization_args.dynamic is True: - return - # initialize on execution device to avoid performing quantized ops on cpu device = get_execution_device(module) @@ -196,7 +192,7 @@ def _initialize_quantization_parameters( # 4. 
Initializes empty scale, zero point, and g_idx parameters for the module # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: + if quantization_args.dynamic is False: init_scale = Parameter( torch.empty(expected_shape, dtype=scale_dtype, device=device), requires_grad=False, @@ -219,3 +215,11 @@ def _initialize_quantization_parameters( requires_grad=False, ) register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + +def get_calibrated_locations(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]: + input = scheme.input_activations and scheme.input_activations.dynamic is not True + weight = scheme.weights is not None + output = scheme.output_activations and scheme.output_activations.dynamic is not True + + return input, weight, output From 8635c1c340e1bb2bd4cc1638e3eb956478137d0f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 17:58:59 -0400 Subject: [PATCH 5/5] add attention preset schemes Signed-off-by: Kyle Sayers --- .../quantization/quant_scheme.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 753afc6c..afec6c7d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -243,6 +243,33 @@ def is_preset_scheme(name: str) -> bool: ), ) +# FP8 attention quantization +FP8_ATTN = dict( + targets=["re:.*self_attn$"], + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TOKEN, + symmetric=True, + dynamic=False, + observer=None, + ), +) + +# FP4 attention quantization +NVFP4_ATTN = dict( + targets=["re:.*self_attn$"], + input_activations=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR_GROUP, + symmetric=True, + dynamic=DynamicType.LOCAL, + group_size=16, + ), +) + + PRESET_SCHEMES = { # Unquantized (no-op) "UNQUANTIZED": UNQUANTIZED, @@ -259,4 +286,7 @@ def is_preset_scheme(name: str) -> bool: "FP8_DYNAMIC": FP8_DYNAMIC, "NVFP4A16": NVFP4A16, "NVFP4": NVFP4, + # Attention activation schemes + "FP8_ATTN": FP8_ATTN, + "NVFP4_ATTN": NVFP4_ATTN, }
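Usage sketch: taken together, the series leaves attention modules targetable through the new FP8_ATTN / NVFP4_ATTN presets, with q/k/v quantization parameters registered by initialize_module_for_quantization and consumed by the calibrated_attention interface (introduced in PATCH 1/5, moved to lc in PATCH 2/5). The snippet below is a minimal sketch, not code taken from the patches: it assumes QuantizationConfig, QuantizationScheme, and apply_quantization_config remain re-exported from compressed_tensors.quantization with their current signatures, and the checkpoint name is an arbitrary small model whose decoder layers expose standard `self_attn` submodules.

# Minimal sketch under the assumptions stated above.
from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

# Any HF causal LM whose decoder layers expose `self_attn` modules works here.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# FP8_ATTN (PATCH 5/5) targets "re:.*self_attn$" and defines only input
# activations, which matches the attention branch added to
# initialize_module_for_quantization in PATCH 2/5.
attn_scheme = QuantizationScheme(**PRESET_SCHEMES["FP8_ATTN"])

config = QuantizationConfig(config_groups={"attention": attn_scheme})
apply_quantization_config(model, config)

# Each matched self_attn module now carries q/k/v scale (and zero point)
# parameters registered by _initialize_quantization_parameters; the
# calibrated_attention interface uses them when quantizing query, key, and
# value states at call time.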