diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index c8dbeced..7a5e5334 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -140,7 +140,6 @@ def apply_quantization_config(
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
-    config = process_quantization_config(config)
     names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
@@ -152,13 +151,7 @@
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
+    for name, submodule in model.named_modules():
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
             for match in matches:
                 ignored_submodules[match].append(name)
@@ -200,42 +193,6 @@
     return names_to_scheme
 
 
-def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
-    """
-    Preprocess the raw QuantizationConfig
-
-    :param config: the raw QuantizationConfig
-    :return: the processed QuantizationConfig
-    """
-    if config.kv_cache_scheme is not None:
-        config = process_kv_cache_config(config)
-
-    return config
-
-
-def process_kv_cache_config(
-    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
-) -> QuantizationConfig:
-    """
-    Reformulate the `config.kv_cache` as a `config_group`
-    and add it to the set of existing `config.groups`
-
-    :param config: the QuantizationConfig
-    :return: the QuantizationConfig with additional "kv_cache" group
-    """
-    if targets == KV_CACHE_TARGETS:
-        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
-
-    kv_cache_dict = config.kv_cache_scheme.model_dump()
-    kv_cache_scheme = QuantizationScheme(
-        output_activations=QuantizationArgs(**kv_cache_dict),
-        targets=targets,
-    )
-    kv_cache_group = dict(kv_cache=kv_cache_scheme)
-    config.config_groups.update(kv_cache_group)
-    return config
-
-
 def apply_quantization_status(model: Module, status: QuantizationStatus):
     """
     Applies in place the quantization lifecycle up to the given status
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index b4ca3a82..ea3aba9f 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -31,6 +31,7 @@
 )
 from compressed_tensors.utils import safe_permute
 from torch.nn import Module
+from transformers import AttentionInterface
 
 
 __all__ = [
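Side note on the `forward.py` hunk: `AttentionInterface` is the registry `transformers` provides for swapping in custom attention implementations by name. Below is a minimal sketch of how a quantized attention hook could be registered through it; the `ct_quantized` name and the hook body are hypothetical, since this hunk only adds the import:

```python
# Sketch only: the "ct_quantized" name and hook body are hypothetical;
# this diff only adds the import and does not show the final wiring.
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward


def ct_quantized_attention(module, query, key, value, attention_mask, **kwargs):
    # a real hook would fake-quantize query/key/value here using the
    # q_scale/k_scale/v_scale parameters registered at initialization
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)


AttentionInterface.register("ct_quantized", ct_quantized_attention)
# models can then opt in via attn_implementation="ct_quantized"
```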
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 806a98f0..b156cf23 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -13,10 +13,8 @@
 # limitations under the License.
 
-import logging
 import math
-from enum import Enum
-from typing import List, Optional
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
@@ -30,7 +28,7 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import is_fp4
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -42,18 +40,10 @@
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
-    "KVCacheScaleType",
+    "get_calibrated_locations",
 ]
 
 
-_LOGGER = logging.getLogger(__name__)
-
-
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
@@ -78,17 +68,17 @@
     # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return
 
-    if is_attention_module(module):
-        # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+    # initialize scheme and status
+    module.quantization_scheme = scheme
+    module.quantization_status = QuantizationStatus.INITIALIZED
 
-    else:
+    input, weight, output = get_calibrated_locations(scheme)
 
-        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
+    if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+        if input:
+            _initialize_quantization_parameters(
                 module,
                 "input",
                 scheme.input_activations,
@@ -96,42 +86,46 @@
                 scale_dtype=scale_dtype,
             )
 
-        if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                    scale_dtype=scale_dtype,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
-
-        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations, scale_dtype=scale_dtype
-                )
-
-        module.quantization_scheme = scheme
-        module.quantization_status = QuantizationStatus.INITIALIZED
+        if weight:
+            _initialize_quantization_parameters(
+                module,
+                "weight",
+                scheme.weights,
+                force_zero_point=force_zero_point,
+                scale_dtype=scale_dtype,
+            )
+
+        if output:
+            _initialize_quantization_parameters(
+                module,
+                "output",
+                scheme.output_activations,
+                force_zero_point=force_zero_point,
+                scale_dtype=scale_dtype,
+            )
 
         with disable_hf_hook(module):
             # wrap forward call of module to perform
             # quantized actions based on calltime status
             wrap_module_forward_quantized(module, scheme)
 
+    elif is_attention_module(module):
+        assert input and scheme.input_activations is not None
+        for base_name in ("q", "k", "v"):
+            _initialize_quantization_parameters(
+                module,
+                base_name,
+                scheme.input_activations,
+                force_zero_point=force_zero_point,
+                scale_dtype=scale_dtype,
+            )
+
+    else:
+        raise ValueError(f"Unsupported quantization target {type(module)}")
+
 
 def is_attention_module(module: Module):
+    # can be redefined to inspect source code for references to ALL_ATTENTION_FUNCTIONS
     return "attention" in module.__class__.__name__.lower() and (
         hasattr(module, "k_proj")
         or hasattr(module, "v_proj")
@@ -139,17 +133,13 @@ def is_attention_module(module: Module):
     )
 
 
-def _initialize_scale_zero_point(
+def _initialize_quantization_parameters(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
    force_zero_point: bool = True,
     scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic is True:
-        return
-
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
@@ -170,7 +160,8 @@ def _initialize_scale_zero_point(
     else:
         expected_shape = 1
 
-    if base_name == "weight" and weight_shape is not None:
+    if base_name == "weight":
+        weight_shape = module.weight.shape
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
@@ -182,7 +173,13 @@ def _initialize_scale_zero_point(
             expected_shape = (weight_shape[0], max(num_groups, 1))
 
     # 3. Identify quantization scale and zp dtype
-    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+    if scale_dtype is None:
+        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+            scale_dtype = module.weight.dtype
+        elif is_attention_module(module):
+            scale_dtype = next(module.parameters()).dtype
+        else:
+            raise ValueError(f"Cannot infer scale dtype for {type(module)}")
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -195,7 +192,7 @@ def _initialize_scale_zero_point(
 
     # 4. Initializes empty scale, zero point, and g_idx parameters for the module
     # do not init scales for quantization_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
+    if quantization_args.dynamic is False:
         init_scale = Parameter(
             torch.empty(expected_shape, dtype=scale_dtype, device=device),
             requires_grad=False,
@@ -220,23 +217,9 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
-    """Initlaize k_scale, v_scale for self_attn"""
-
-    expected_shape = 1  # per tensor
-
-    param = next(module.parameters())
-    scale_dtype = param.dtype
-    device = param.device
+def get_calibrated_locations(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]:
+    input = scheme.input_activations and scheme.input_activations.dynamic is not True
+    weight = scheme.weights is not None
+    output = scheme.output_activations and scheme.output_activations.dynamic is not True
 
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
-
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+    return input, weight, output
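To make the refactored initialization flow concrete, here is a minimal sketch against a plain `Linear`; the import paths and the default `force_zero_point=True` behavior are assumptions based on the current package layout, not part of this diff:

```python
# Sketch: new initialization path on a plain Linear layer.
# Import paths are assumed from the current compressed_tensors layout.
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization

linear = torch.nn.Linear(64, 64)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    input_activations=QuantizationArgs(num_bits=8, dynamic=True),
)
initialize_module_for_quantization(linear, scheme)

# weights are static, so weight_scale/weight_zero_point are registered;
# inputs are dynamic, so get_calibrated_locations reports them as not calibrated
assert hasattr(linear, "weight_scale") and hasattr(linear, "weight_zero_point")
assert not hasattr(linear, "input_scale")
```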
is_attention_module(module: Module): + # can redefine to inspect source code for references to ALL_ATTENTION_FUNCTIONS return "attention" in module.__class__.__name__.lower() and ( hasattr(module, "k_proj") or hasattr(module, "v_proj") @@ -139,17 +133,13 @@ def is_attention_module(module: Module): ) -def _initialize_scale_zero_point( +def _initialize_quantization_parameters( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, force_zero_point: bool = True, scale_dtype: Optional[torch.dtype] = None, ): - if quantization_args.dynamic is True: - return - # initialize on execution device to avoid performing quantized ops on cpu device = get_execution_device(module) @@ -170,7 +160,8 @@ def _initialize_scale_zero_point( else: expected_shape = 1 - if base_name == "weight" and weight_shape is not None: + if base_name == "weight": + weight_shape = getattr(module, "weight").shape if quantization_args.strategy == QuantizationStrategy.CHANNEL: # (output_channels, 1) expected_shape = (weight_shape[0], 1) @@ -182,7 +173,13 @@ def _initialize_scale_zero_point( expected_shape = (weight_shape[0], max(num_groups, 1)) # 3. Identify quantization scale and zp dtype - scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype + if scale_dtype is None: + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + scale_dtype = module.weight.dtype + elif is_attention_module(module): + scale_dtype = next(module.parameters()).dtype + else: + raise ValueError() if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -195,7 +192,7 @@ def _initialize_scale_zero_point( # 4. Initializes empty scale, zero point, and g_idx parameters for the module # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: + if quantization_args.dynamic is False: init_scale = Parameter( torch.empty(expected_shape, dtype=scale_dtype, device=device), requires_grad=False, @@ -220,23 +217,9 @@ def _initialize_scale_zero_point( register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) -def _initialize_attn_scales(module: Module) -> None: - """Initlaize k_scale, v_scale for self_attn""" - - expected_shape = 1 # per tensor - - param = next(module.parameters()) - scale_dtype = param.dtype - device = param.device +def get_calibrated_locations(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]: + input = scheme.input_activations and scheme.input_activations.dynamic is not True + weight = scheme.weights is not None + output = scheme.output_activations and scheme.output_activations.dynamic is not True - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale) - - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) + return input, weight, output diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 753afc6c..afec6c7d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -243,6 +243,33 @@ def is_preset_scheme(name: str) -> bool: ), ) +# FP8 attention quantization +FP8_ATTN = dict( + targets=["re:.*self_attn$"], + input_activations=QuantizationArgs( + 
diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py
index 16ab10b3..8b114812 100644
--- a/src/compressed_tensors/transform/transform_args.py
+++ b/src/compressed_tensors/transform/transform_args.py
@@ -33,8 +33,8 @@ class TransformLocation(str, Enum):
     | `WEIGHT_INPUT`  | offline     | weight        | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT`        |  # noqa: E501
     | `WEIGHT_OUTPUT` | offline     | weight        | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`         |  # noqa: E501
     | `OUTPUT`        | online      | activations   | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`  |  # noqa: E501
-    | `K_CACHE`       | online      | key_values    | `q_proj.Q_ATTN`                                          |  # noqa: E501
-    | `Q_ATTN`        | online      | query_values  | `k_proj.K_CACHE`                                         |  # noqa: E501
+    | `ATTN_Q`        | online      | query_states  | `this.ATTN_K`                                            |  # noqa: E501
+    | `ATTN_K`        | online      | key_states    | `this.ATTN_Q`                                            |  # noqa: E501
     | -------------------------------------------------------------------------------------------------------- |  # noqa: E501
     """
 
@@ -42,8 +42,9 @@ class TransformLocation(str, Enum):
     INPUT = "input"
     WEIGHT_INPUT = "weight_input"
     WEIGHT_OUTPUT = "weight_output"
     OUTPUT = "output"
-    K_CACHE = "k_cache"
-    Q_ATTN = "q_attn"
+    ATTN_Q = "attn_q"
+    ATTN_K = "attn_k"
+    # ATTN_V = "attn_v"
 
 
 class TransformArgs(BaseModel):
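For completeness, a sketch of the renamed locations in use; field names follow `TransformArgs` as defined in this file, and the regex target is illustrative:

```python
# Sketch: attaching transforms to the renamed attention locations.
# Field names follow TransformArgs; the target pattern is illustrative.
from compressed_tensors.transform import TransformArgs, TransformLocation

rotate_q = TransformArgs(targets=["re:.*self_attn$"], location=TransformLocation.ATTN_Q)
rotate_k = TransformArgs(targets=["re:.*self_attn$"], location=TransformLocation.ATTN_K)
assert rotate_q.location.value == "attn_q"
```

Pairing `ATTN_Q` with `ATTN_K` on the same module mirrors the docstring table: a rotation applied to query states must be inverted on key states so the attention product is preserved.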