diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py
index 876db7138..6671af923 100644
--- a/examples/transform/spinquant_example.py
+++ b/examples/transform/spinquant_example.py
@@ -13,7 +13,8 @@
     MODEL_ID,
     torch_dtype="auto",
+    attn_implementation="eager",
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
@@ -58,8 +59,10 @@ def tokenize(sample):
 # * apply spinquant transforms to model in order to make quantization easier
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = [
-    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
-    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+    SpinQuantModifier(
+        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
+    ),
+    # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
 
 # Apply algorithms.
@@ -75,9 +78,12 @@ def tokenize(sample):
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
+from llmcompressor.utils import calibration_forward_context
+
+with calibration_forward_context(model):
+    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+    output = model.generate(input_ids, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
 print("==========================================\n\n")
 
 # Save to disk compressed.
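Reviewer note (not part of the patch): every rotation pair the recipe adds composes to the identity, so pre-quantization model outputs are unchanged; only the quantizer sees the flattened, outlier-free activations. A self-contained sketch of that invariant for a fused pair like R4's input/weight_input scheme (toy sizes, scipy used only to build the Hadamard matrix):

import torch
from scipy.linalg import hadamard

d = 64
R = torch.tensor(hadamard(d), dtype=torch.float64) / d**0.5  # orthogonal: R @ R.T == I

w1 = torch.randn(d, d, dtype=torch.float64)  # layer whose output is rotated
w2 = torch.randn(d, d, dtype=torch.float64)  # layer whose input un-rotates
x = torch.randn(8, d, dtype=torch.float64)

w1_rot = R.T @ w1  # fold R into w1's output side
w2_rot = w2 @ R    # fold R^T into w2's input side (the inverse=True leg)

y_ref = x @ w1.T @ w2.T          # original pair of linear layers (y = x W^T)
y_rot = x @ w1_rot.T @ w2_rot.T  # rotated pair
assert torch.allclose(y_ref, y_rot)  # rotations cancel exactly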
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..fe824695e 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Tuple
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
 
 import torch
 from compressed_tensors.quantization import (
@@ -13,11 +14,18 @@
 from compressed_tensors.utils import align_module_device, update_parameter_data
 from loguru import logger
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
+if TYPE_CHECKING:
+    from compressed_tensors.modeling.attention import CompressedAttentionImpl
+
+    from llmcompressor.modifiers.utils.hooks import HooksMixin
+
+
 DEFAULT_MAXSHRINK = 0.20
 DEFAULT_PATIENCE = 5
 DEFAULT_AVERAGING_CONSTANT = 0.01
@@ -25,6 +33,7 @@
 DEFAULT_NORM = 2.4
 
 __all__ = [
+    "register_calibrate_attn_hooks",
     "initialize_observer",
     "update_weight_zp_scale",
     "calibrate_input_hook",
@@ -205,14 +214,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     )
 
 
-def calibrate_input_hook(module: Module, args: Any):
+def register_calibrate_attn_hooks(
+    modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
+) -> Set[RemovableHandle]:
+    return {
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="q"), "query"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="k"), "key"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="v"), "value"
+        ),
+    }
+
+
+def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
     """
     Hook to calibrate input activations.
     Will call the observers to update the scales/zp before applying
     input QDQ in the module's forward pass.
     """
     args = args[0] if isinstance(args, tuple) else args
-    calibrate_activations(module, value=args, base_name="input")
+    calibrate_activations(module, value=args, base_name=base_name)
 
 
 def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -282,6 +307,14 @@ def initialize_quantized_kv_cache(module: Module):
     setattr(module, "kv_cache", quantized_kv_cache)
 
 
+def initialize_attention_observers(module: Module):
+    input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
+    if input_args is not None:
+        initialize_observer(module, "q", input_args)
+        initialize_observer(module, "k", input_args)
+        initialize_observer(module, "v", input_args)
+
+
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index d193d85a1..7c7a41033 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         # kv_cache activations. Within `apply_quantization_config`, the config is
         # modified to use attention output quantization if a kv_cache_scheme exists
         if is_attention and output:
+            # initialize_attention_observers(module)  # TODO: attnq
             initialize_quantized_kv_cache(module)
 
         # output activations
@@ -240,6 +241,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
+
        for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
@@ -258,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                 self.register_hook(module, calibrate_input_hook, "forward_pre")
             )
 
+            # TODO: attnq
+            # if is_attention:
+            #     attention_impl = CompressedAttentionImpl.from_module(module)
+            #     hooks |= register_calibrate_attn_hooks(self, attention_impl)
+
             # kv_cache activations. Within `apply_quantization_config`, the config is
             # modified to use attention output quantization if a kv_cache_scheme exists
             if is_attention and output:
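Reviewer note (not part of the patch): `register_calibrate_attn_hooks` reuses a single hook body for the query/key/value streams by pre-binding `base_name` with `functools.partial`; the "query"/"key"/"value" hook points are provided by `CompressedAttentionImpl`. The toy sketch below substitutes a plain forward-pre hook and a dict for the real observers, just to show the fan-out pattern:

from functools import partial

import torch

observed = {}

def calibrate(module: torch.nn.Module, args, base_name: str = "input"):
    value = args[0] if isinstance(args, tuple) else args
    observed[base_name] = value.abs().amax().item()  # stand-in for an Observer update

attn = torch.nn.Identity()
for name in ("q", "k", "v"):
    attn.register_forward_pre_hook(partial(calibrate, base_name=name))

attn(torch.randn(2, 8))
print(observed)  # {'q': ..., 'k': ..., 'v': ...} -- each partial keeps its own name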
diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 5978b93ea..bd78525d3 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -109,7 +109,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             config_groups["R2"] = self._create_r2_scheme(state.model)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(state.model)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -214,8 +214,47 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+    def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            head_dim = config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
+
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_q",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_k",
+                ),
+            ],
+        )
 
     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            apply=[
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="input",
+                ),
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+        )
diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
index 7dc327b78..36102b975 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -10,6 +10,7 @@ class SpinQuantMapping(BaseModel):
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -31,6 +32,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
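Reviewer note (not part of the patch): R3 rotates per attention head inside `self_attn`, which is why `_create_r3_scheme` now takes the model and passes `head_dim` to the `TransformScheme`. A quick illustration of the fallback computation on a hypothetical Llama-style config (values are illustrative only):

from transformers import LlamaConfig

# hidden_size / num_attention_heads is the standard fallback when the config
# does not carry an explicit head_dim.
config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
head_dim = getattr(config, "head_dim", None) or (
    config.hidden_size // config.num_attention_heads
)
print(head_dim)  # 128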