diff --git a/examples/attention_quantization/llama3_example.py b/examples/attention_quantization/llama3_example.py
new file mode 100644
index 000000000..610aaf0fa
--- /dev/null
+++ b/examples/attention_quantization/llama3_example.py
@@ -0,0 +1,64 @@
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import oneshot
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+model_id = "meta-llama/Llama-3.2-1B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Select number of calibration samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Configure the quantization algorithm to run.
+# * quantize attention q/k/v activations to fp8 with static per-tensor scales
+recipe = QuantizationModifier(
+    config_groups={
+        "attention_quant": QuantizationScheme(
+            targets=["re:.*self_attn$"],
+            input_activations=QuantizationArgs(num_bits=8, type="float"),
+        ),
+    },
+    ignore=["lm_head"],
+)
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset="ultrachat_200k",
+    splits={"calibration": f"test_sft[:{NUM_CALIBRATION_SAMPLES}]"},
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
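+# (Optional, illustrative) sanity check that calibration attached a quantization
+# scheme and recorded q/k/v activation scales on the attention modules. The
+# "<base_name>_scale" parameter names are an assumption based on the q/k/v base
+# names used by the calibrated attention interface; the hasattr guard keeps this
+# harmless if the naming differs.
+attn = model.model.layers[0].self_attn
+print("scheme:", getattr(attn, "quantization_scheme", None))
+for name in ("q_scale", "k_scale", "v_scale"):
+    if hasattr(attn, name):
+        print(name, getattr(attn, name))
+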
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-FP8-attention"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modifiers/quantization/__init__.py b/src/llmcompressor/modifiers/quantization/__init__.py
index f1cdf596c..226869f39 100644
--- a/src/llmcompressor/modifiers/quantization/__init__.py
+++ b/src/llmcompressor/modifiers/quantization/__init__.py
@@ -1,5 +1,4 @@
 # flake8: noqa
 
-from .cache import *
 from .gptq import *
 from .quantization import *
diff --git a/src/llmcompressor/modifiers/quantization/attention.py b/src/llmcompressor/modifiers/quantization/attention.py
new file mode 100644
index 000000000..b8f6c805d
--- /dev/null
+++ b/src/llmcompressor/modifiers/quantization/attention.py
@@ -0,0 +1,63 @@
+from typing import Optional
+
+import torch
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    forward_quantize,
+)
+from compressed_tensors.transform import TransformBase, TransformLocation
+from transformers.modeling_utils import AttentionInterface
+from transformers.models.llama.modeling_llama import eager_attention_forward
+
+from llmcompressor.modifiers.quantization.calibration import calibrate_activations
+
+
+def calibrated_attention(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    """
+    Attention interface that applies any attached attention-location transforms to
+    the query/key states, calibrates and fake-quantizes the q/k/v activations
+    according to the module's quantization scheme, then runs eager attention.
+    """
+    for submodule in module.children():
+        if isinstance(submodule, TransformBase):
+            # 1. apply transforms
+            if submodule.args.location == TransformLocation.ATTN_Q:
+                query = submodule(query)
+
+            if submodule.args.location == TransformLocation.ATTN_K:
+                key = submodule(key)
+
+            # if submodule.args.location == TransformLocation.ATTN_V:
+            #     value = submodule(value)
+
+    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
+    status: Optional[QuantizationStatus] = getattr(module, "quantization_status", None)
+    if getattr(scheme, "input_activations", None) is not None:
+        # 2. calibrate quantization
+        if status == QuantizationStatus.CALIBRATION:
+            calibrate_activations(module, value=query, base_name="q")
+            calibrate_activations(module, value=key, base_name="k")
+            calibrate_activations(module, value=value, base_name="v")
+
+        # 3. apply quantization
+        if status in (QuantizationStatus.CALIBRATION, QuantizationStatus.FROZEN):
+            query = forward_quantize(module, query, "q", scheme.input_activations)
+            key = forward_quantize(module, key, "k", scheme.input_activations)
+            value = forward_quantize(module, value, "v", scheme.input_activations)
+
+    return eager_attention_forward(
+        module, query, key, value, attention_mask, scaling, dropout, **kwargs
+    )
+
+
+AttentionInterface.register("calibrated_attention", calibrated_attention)
diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
deleted file mode 100644
index dd3640dda..000000000
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ /dev/null
@@ -1,208 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-from compressed_tensors.quantization.lifecycle import KVCacheScaleType
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from torch import Tensor
-from transformers import DynamicCache
-
-from llmcompressor.observers import Observer
-
-
-class QuantizedKVParameterCache(DynamicCache):
-    """
-    Quantized KV cache used in the forward call based on HF's dynamic cache.
- Quantization strategy (tensor, group, channel) set from Quantization arg's strategy - Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. - The size of tensor is - `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - - Triggered by adding kv_cache_scheme in the recipe. - - Example: - - ```python3 - recipe = ''' - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - ''' - - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - """Singleton""" - if cls._instance is None: - cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) - return cls._instance - - def __init__(self, quantization_args: QuantizationArgs): - if not self._initialized: - super().__init__() - - self.quantization_args = quantization_args - - self.k_observers: List[Observer] = [] - self.v_observers: List[Observer] = [] - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - - self._initialized = True - - def update( - self, - key_states: Tensor, - value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states - """ - - if len(self.k_observers) <= layer_idx: - k_observer_name = self.quantization_args.observer - k_observer = Observer.load_from_registry( - k_observer_name, quantization_args=self.quantization_args - ) - v_observer_name = self.quantization_args.observer - v_observer = Observer.load_from_registry( - v_observer_name, quantization_args=self.quantization_args - ) - - # NOTE: User may ignore some layers in configuration, - # meaning len(self.k_observers) <= layer_idx-1 - # Must account for that case by padding list so that - # index of lists corresponds to layer_idx - _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer) - _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer) - - q_key_states = self._quantize( - key_states.contiguous(), KVCacheScaleType.KEY, layer_idx - ) - q_value_states = self._quantize( - value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx - ) - - qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize( - q_value_states, KVCacheScaleType.VALUE, layer_idx - ) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. - A layer index can be optionally passed. 
- """ - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def reset_states(self): - """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen - self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] - - def reset(self): - """ - Reset the instantiation, create new instance on init - """ - QuantizedKVParameterCache._instance = None - QuantizedKVParameterCache._initialized = False - - def _quantize(self, tensor, kv_type, layer_idx): - """Quantizes a key/value using a defined quantization method.""" - from compressed_tensors.quantization.lifecycle.forward import quantize - - if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] - scales = self.k_scales - zps = self.k_zps - else: - assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] - scales = self.v_scales - zps = self.v_zps - - scale, zp = observer(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale) - _pad_and_append_at_idx_(zps, layer_idx, zp) - - q_tensor = quantize( - x=tensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return qdq_tensor - - -# NOTE: Using _ suffix to denote l is modified in place -def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: - """ - Append value val to list lst at index idx, right padding if necessary - Needed because user may ignore some layers in configuration, meaning - len(lst) <= idx-1 - - >>> _pad_and_append_at_idx_([0,1,2], 5, 5) - [0, 1, 2, None, None, 5] - >>> _pad_and_append_at_idx_([0,1,2], 3, 8) - [0, 1, 2, 8] - >>> _pad_and_append_at_idx_([0,1,2], 1, 5) - [0, 5, 2] - """ - num_to_pad = idx - len(lst) + 1 - if num_to_pad > 0: - lst += [None] * num_to_pad - lst[idx] = val - return lst diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index b10a4cb31..5316645b0 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,20 +1,17 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import torch from compressed_tensors.quantization import ( DynamicType, - KVCacheScaleType, - QuantizationScheme, + QuantizationArgs, QuantizationStatus, QuantizationStrategy, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize -from compressed_tensors.quantization.utils import 
is_kv_cache_quant_scheme from compressed_tensors.utils import align_module_device, update_parameter_data from loguru import logger from torch.nn import Module -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain @@ -29,9 +26,6 @@ "update_weight_zp_scale", "calibrate_input_hook", "calibrate_output_hook", - "calibrate_kv_cache_input_hook", - "calibrate_kv_cache_output_hook", - "initialize_quantized_kv_cache", "freeze_module_quantization", "apply_calibration_status", "reset_quantization_status", @@ -42,6 +36,7 @@ def initialize_observer( module: Module, base_name: str, + quantization_args: Optional[QuantizationArgs], ): """ Initialize observer module and attach as submodule. @@ -53,14 +48,6 @@ def initialize_observer( :param base_name: str used to name the observer attribute """ - - arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: - # no quantization scheme nothing to do - return - - quantization_args = getattr(quantization_scheme, arg_name, None) # dont need observers for dynamic if quantization_args is not None and quantization_args.dynamic in ( False, @@ -235,53 +222,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): return output -def calibrate_kv_cache_input_hook( - module: Module, args: Any, kwargs: Dict[str, Any] -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """ - Hook to update inputs to attention layers when running - kv_cache quantization. Will update the passed in - kv_cache to singleton QuantizedKVParameterCache. - """ - kv_cache = getattr(module, "kv_cache") - kwargs["past_key_value"] = kv_cache - kwargs["use_cache"] = False - return args, kwargs - - -def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): - """ - Hook to update k_scale and v_scale parameters when running kv_cache quantization. - """ - kv_cache = getattr(module, "kv_cache") - k_scale = kv_cache.k_scales[module.layer_idx] - v_scale = kv_cache.v_scales[module.layer_idx] - update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value) - update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value) - - -def initialize_quantized_kv_cache(module: Module): - """ - Initialize a quantized kv_cache on a module (analogous to initializing an observer) - When a config specifying kv_cache quantization is applied to a model, the kv_cache - args are redefined as the output_activations targeting attention modules. 
- - This function should be called on attention modules with output_activations - """ - scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) - existing_kv_cache = getattr(module, "kv_cache", None) - - if ( - scheme is None - or not is_kv_cache_quant_scheme(scheme) - or isinstance(existing_kv_cache, QuantizedKVParameterCache) - ): - return - - quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations) - setattr(module, "kv_cache", quantized_kv_cache) - - def apply_calibration_status(module: Module): scheme = getattr(module, "quantization_scheme", None) if not scheme: @@ -313,11 +253,6 @@ def freeze_module_quantization(module: Module): if hasattr(module, obs_name): delattr(module, obs_name) - # remove quantized kv_cache - kv_cache = getattr(module, "kv_cache", None) - if isinstance(kv_cache, QuantizedKVParameterCache): - delattr(module, "kv_cache") - module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index d193d85a1..c2db1945c 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -1,8 +1,7 @@ -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch from compressed_tensors.quantization import ( - DynamicType, QuantizationArgs, QuantizationConfig, QuantizationScheme, @@ -20,12 +19,9 @@ from llmcompressor.modifiers.quantization.calibration import ( apply_calibration_status, calibrate_input_hook, - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, calibrate_output_hook, freeze_module_quantization, initialize_observer, - initialize_quantized_kv_cache, reset_quantization_status, ) from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -138,7 +134,9 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ - self._calibration_hooks = self._initialize_hooks(model) + self._calibration_hooks = set() + for module in model.modules(): + self._calibration_hooks |= self._initialize_hooks(module) model.apply(apply_calibration_status) model.apply(enable_quantization) # quantize at the same time as calibrate @@ -209,48 +207,43 @@ def resolve_quantization_config(self) -> QuantizationConfig: ) def _initialize_observers(self, module: torch.nn.Module): - if not hasattr(module, "quantization_scheme"): + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + if scheme is None: return - scheme: QuantizationScheme = module.quantization_scheme - input = scheme.input_activations and scheme.input_activations.dynamic in ( - False, - DynamicType.LOCAL, - ) - weight = scheme.weights is not None - output = scheme.output_activations and not scheme.output_activations.dynamic - is_attention = is_attention_module(module) + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + input, weight, output = _get_observer_targets(scheme) + + # input activations + if input: + initialize_observer(module, "input", scheme.input_activations) - # input activations - if input: - initialize_observer(module, base_name="input") + # weight observers (used by `update_weight_zp_scale` or child modifier) + if weight: + initialize_observer(module, "weight", scheme.weights) - # weight observers (used by `update_weight_zp_scale` or child modifier) - if weight: - 
initialize_observer(module, base_name="weight") + # output activations + if output: + initialize_observer(module, "output", scheme.output_activations) - # kv_cache activations. Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - initialize_quantized_kv_cache(module) + elif is_attention_module(module): + # attention observers + initialize_observer(module, "q", scheme.input_activations) + initialize_observer(module, "k", scheme.input_activations) + initialize_observer(module, "v", scheme.input_activations) - # output activations - elif output: - initialize_observer(module, base_name="output") + else: + raise ValueError(f"Unsupported quantization target {type(module)}") - def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: + def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]: hooks = set() - for module in model.modules(): - if not hasattr(module, "quantization_scheme"): - continue + if not hasattr(module, "quantization_scheme"): + return hooks - scheme: QuantizationScheme = module.quantization_scheme - input = scheme.input_activations and scheme.input_activations.dynamic in ( - False, - DynamicType.LOCAL, - ) - output = scheme.output_activations and not scheme.output_activations.dynamic - is_attention = is_attention_module(module) + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + input, _, output = _get_observer_targets(module.quantization_scheme) # input activations if input: @@ -258,25 +251,41 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: self.register_hook(module, calibrate_input_hook, "forward_pre") ) - # kv_cache activations. Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - hooks.add( - self.register_hook( - module, - calibrate_kv_cache_input_hook, - "forward_pre", - with_kwargs=True, - ) - ) - hooks.add( - self.register_hook( - module, calibrate_kv_cache_output_hook, "forward" - ) - ) - # output activations - elif output: + if output: hooks.add(self.register_hook(module, calibrate_output_hook, "forward")) + elif is_attention_module(module): + # wrap attention interface + tmp = None + + # This import is purely so that "calibrated_attention" gets registered + # as an attention implementation. 
There's probably a better way to do this + # so that registration only happens once before the model runs + from llmcompressor.modifiers.quantization.attention import ( # noqa: F401 + calibrated_attention, + ) + + def forward_pre(self, *args, **kwargs): + nonlocal tmp + tmp = self.config._attn_implementation + self.config._attn_implementation = "calibrated_attention" + + def forward(self, *args, **kwargs): + self.config._attn_implementation = tmp + + hooks.add(self.register_hook(module, forward_pre, "forward_pre")) + hooks.add(self.register_hook(module, forward, "forward")) + + else: + raise ValueError(f"Unsupported quantization target {type(module)}") + return hooks + + +def _get_observer_targets(scheme: QuantizationScheme) -> Tuple[bool, bool, bool]: + input = scheme.input_activations and scheme.input_activations.dynamic is not True + weight = scheme.weights is not None + output = scheme.output_activations and scheme.output_activations.dynamic is not True + + return input, weight, output diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py deleted file mode 100644 index 898c342f5..000000000 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs - -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache -from llmcompressor.observers import Observer - - -def test_is_quantized_cache_singleton(): - """ - Check if quantized_cache is a singleton, used for - passing in QuantizedKVParameterCache to the forward call of - the model's self_attn - """ - - args = QuantizationArgs() - cache = QuantizedKVParameterCache(args) - observer = args.observer - observer = Observer.load_from_registry(observer, quantization_args=args) - - tensor = torch.tensor([1, 2, 3]) - cache.k_scales.append(tensor) - cache.k_observers.append(observer) - - same_cache = QuantizedKVParameterCache(args) - - assert len(cache.k_scales) == len(same_cache.k_scales) - assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) - - assert cache.k_observers == same_cache.k_observers - assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) - - cache.reset() - - -def test_update(): - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - denom = (2 ** (nbits) - 1) / 2 - expected_k_scale = torch.tensor([max_key_states_val / denom]) - expected_v_scale = torch.tensor([max_value_states_val / denom]) - - assert cache.k_scales[0] == expected_k_scale - assert cache.v_scales[0] == expected_v_scale - - # new attn layer - layer_idx = 1 - cache.update(key_states, value_states, layer_idx) - - assert len(cache.k_scales) == 2 - assert len(cache.v_scales) == 2 - - assert len(cache.k_observers) == 2 - assert len(cache.v_observers) == 2 - - cache.reset() - - -def test_cache_reset(): - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - assert len(cache.k_scales) == 1 - assert len(cache.v_scales) == 1 - - assert len(cache.k_observers) == 1 - assert len(cache.v_observers) == 1 - - cache.reset() - - # new instance, different memory addr - different_cache = QuantizedKVParameterCache(args) - - assert len(different_cache.k_scales) == 0 - assert len(different_cache.v_scales) == 0 - - assert len(different_cache.k_observers) == 0 - assert len(different_cache.v_observers) == 0 - - assert hex(id(cache)) != hex(id(different_cache)) diff --git a/tests/llmcompressor/modifiers/calibration/test_frozen.py b/tests/llmcompressor/modifiers/calibration/test_frozen.py index 4b89a0084..ec2cddcfe 100644 --- a/tests/llmcompressor/modifiers/calibration/test_frozen.py +++ b/tests/llmcompressor/modifiers/calibration/test_frozen.py @@ -38,7 +38,7 @@ def test_set_module_for_calibration(): initialize_module_for_quantization(layer, quantization_scheme) layer.quantization_status = QuantizationStatus("calibration") - initialize_observer(layer, "weight") + initialize_observer(layer, "weight", quantization_scheme.weights) # 
should have both input and weight observer after initalizing assert hasattr(layer, "weight_observer") diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py deleted file mode 100644 index b22e7ec40..000000000 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, - is_attention_module, -) -from transformers import AutoModelForCausalLM - -from llmcompressor.modifiers.quantization.calibration import ( - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, - freeze_module_quantization, - initialize_quantized_kv_cache, -) - -config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "kv_cache_scheme": { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "config_groups": { - "group_1": { - "weights": { - "num_bits": 4, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, -} - - -def _prep_for_calibration(module: torch.nn.Module): - if is_attention_module(module): - module.register_forward_pre_hook( - calibrate_kv_cache_input_hook, with_kwargs=True - ) - module.register_forward_hook(calibrate_kv_cache_output_hook) - module.quantization_status = QuantizationStatus.CALIBRATION - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization(config): - sample = { - name: torch.ones((1, 32)).long() - for name in ["input_ids", "attention_mask", "labels"] - } - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - model.apply(initialize_quantized_kv_cache) - model.apply(_prep_for_calibration) - - with torch.no_grad(): - _ = model(**sample) - - model.apply(freeze_module_quantization) - - reloaded_config = QuantizationConfig.from_pretrained(model) - - assert ( - config.kv_cache_scheme.model_dump().keys() - == reloaded_config.kv_cache_scheme.model_dump().keys() - ) - assert list(config.kv_cache_scheme.model_dump().values()) == list( - reloaded_config.kv_cache_scheme.model_dump().values() - ) diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index a742a48b2..dd0df158c 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -39,9 +39,9 @@ def test_observers_update(shape, group_size, actorder): output = torch.empty(module.out_features, dtype=module.weight.dtype) initialize_module_for_quantization(module, 
scheme) - initialize_observer(module, "weight") - initialize_observer(module, "input") - initialize_observer(module, "output") + initialize_observer(module, "weight", scheme.weights) + initialize_observer(module, "input", scheme.input_activations) + initialize_observer(module, "output", scheme.output_activations) for location, value in ( ("weight", module.weight),
diff --git a/tests/llmcompressor/observers/test_helpers.py b/tests/llmcompressor/observers/test_helpers.py
index 527176019..02f3db1b6 100644
--- a/tests/llmcompressor/observers/test_helpers.py
+++ b/tests/llmcompressor/observers/test_helpers.py
@@ -56,7 +56,10 @@ def test_get_observer_token_count():
         },
     )
     apply_quantization_config(model, config)
-    model.apply(lambda module: initialize_observer(module, base_name="input"))
+    for module in model.modules():
+        scheme = getattr(module, "quantization_scheme", None)
+        if scheme is not None:
+            initialize_observer(module, "input", scheme.input_activations)
     model.apply(_prep_for_input_quant_calibration)
 
     # start calibration
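Note on the calibration hooks above: during calibration the mixin's forward_pre hook points each attention module's `config._attn_implementation` at the "calibrated_attention" key registered in attention.py, and the forward hook restores the previous implementation afterwards. For reviewers unfamiliar with this transformers mechanism, here is a minimal, self-contained sketch of the same registration pattern; the `logging_attention` wrapper and its printout are illustrative only and not part of this diff.

from transformers.modeling_utils import AttentionInterface
from transformers.models.llama.modeling_llama import eager_attention_forward


def logging_attention(
    module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs
):
    # same signature calibrated_attention uses; this is the point where q/k/v can
    # be observed or fake-quantized before delegating to the stock implementation
    print(f"{type(module).__name__}: q={tuple(query.shape)} k={tuple(key.shape)}")
    return eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout, **kwargs
    )


AttentionInterface.register("logging_attention", logging_attention)

# The hooks then swap the implementation only for the duration of a forward pass:
#   previous = attn_module.config._attn_implementation
#   attn_module.config._attn_implementation = "logging_attention"
#   ...calibration forward pass...
#   attn_module.config._attn_implementation = previous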