From 0a146f8422fca6cd3c13bba71c079ccbf68e2e84 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 16 Jul 2025 16:28:44 -0400
Subject: [PATCH 1/4] hook with CompressedAttentionImpl

Signed-off-by: Kyle Sayers
---
 .../modifiers/quantization/calibration.py | 16 ++++++++++++++++
 .../modifiers/quantization/quantization/mixin.py | 6 ++++++
 2 files changed, 22 insertions(+)

diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..84e8ead9c 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -282,6 +282,22 @@ def initialize_quantized_kv_cache(module: Module):
     setattr(module, "kv_cache", quantized_kv_cache)
 
 
+def initialize_attention_observers(module: Module):
+    input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
+    if input_args is not None:
+        initialize_observer(module, "q", input_args)
+        initialize_observer(module, "k", input_args)
+        initialize_observer(module, "v", input_args)
+
+
+def calibrate_attention(
+    module: Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+):
+    calibrate_activations(module, value=query, base_name="q")
+    calibrate_activations(module, value=key, base_name="k")
+    calibrate_activations(module, value=value, base_name="v")
+
+
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index d193d85a1..38cb0f455 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         # kv_cache activations. Within `apply_quantization_config`, the config is
         # modified to use attention output quantization if a kv_cache_scheme exists
         if is_attention and output:
+            # initialize_attention_observers(module)  # TODO: attnq
             initialize_quantized_kv_cache(module)
 
         # output activations
@@ -240,6 +241,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
+
+        # TODO: attnq
+        # attention_impl = enable_compressed_attention(model)
+        # hooks.add(self.register_hook(attention_impl, calibrate_attention, "calib"))
+
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue

From 0e4e00279f7f529f7d75fa39247587b675f3bfdc Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 16 Jul 2025 17:10:14 -0400
Subject: [PATCH 2/4] use qkv hooks

Signed-off-by: Kyle Sayers
---
 .../modifiers/quantization/calibration.py | 37 +++++++++++++------
 .../quantization/quantization/mixin.py | 2 +-
 .../modifiers/transform/spinquant/base.py | 37 ++++++++++++++++++-
 .../modifiers/transform/spinquant/mappings.py | 2 +
 4 files changed, 64 insertions(+), 14 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 84e8ead9c..e460e1f27 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Tuple
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
 
 import torch
 from compressed_tensors.quantization import (
@@ -13,11 +14,16 @@
 from compressed_tensors.utils import align_module_device, update_parameter_data
 from loguru import logger
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
+if TYPE_CHECKING:
+    from llmcompressor.modifiers.utils.hooks import HooksMixin
+
+
 DEFAULT_MAXSHRINK = 0.20
 DEFAULT_PATIENCE = 5
 DEFAULT_AVERAGING_CONSTANT = 0.01
@@ -25,6 +31,7 @@
 DEFAULT_NORM = 2.4
 
 __all__ = [
+    "register_calibrate_attn_hooks",
     "initialize_observer",
     "update_weight_zp_scale",
     "calibrate_input_hook",
@@ -205,14 +212,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     )
 
 
-def calibrate_input_hook(module: Module, args: Any):
+def register_calibrate_attn_hooks(
+    modifier: HooksMixin, attention_impl
+) -> Set[RemovableHandle]:
+    return {
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="q"), "query"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="k"), "key"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="v"), "value"
+        ),
+    }
+
+
+def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
     """
     Hook to calibrate input activations.
     Will call the observers to update the scales/zp before applying
     input QDQ in the module's forward pass.
     """
     args = args[0] if isinstance(args, tuple) else args
-    calibrate_activations(module, value=args, base_name="input")
+    calibrate_activations(module, value=args, base_name=base_name)
 
 
 def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -290,14 +313,6 @@ def initialize_attention_observers(module: Module):
     initialize_observer(module, "v", input_args)
 
 
-def calibrate_attention(
-    module: Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-):
-    calibrate_activations(module, value=query, base_name="q")
-    calibrate_activations(module, value=key, base_name="k")
-    calibrate_activations(module, value=value, base_name="v")
-
-
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 38cb0f455..4f1ee46a2 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -244,7 +244,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         # TODO: attnq
         # attention_impl = enable_compressed_attention(model)
-        # hooks.add(self.register_hook(attention_impl, calibrate_attention, "calib"))
+        # hooks |= register_calibrate_attn_hooks(self, attention_impl)
 
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 5978b93ea..2204aec13 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -215,7 +215,40 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
         )
 
     def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        return (
+            TransformScheme(
+                type=self.transform_type,
+                randomize=self.randomize,
+                requires_grad=self.learnable,
+                apply=[
+                    TransformArgs(
+                        targets=[self.mappings.attn],
+                        location="attn_q",
+                    ),
+                    TransformArgs(
+                        targets=[self.mappings.attn],
+                        location="attn_k",
+                    ),
+                ],
+            ),
+        )
 
     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        return (
+            TransformScheme(
+                type=self.transform_type,
+                randomize=self.randomize,
+                requires_grad=self.learnable,
+                apply=[
+                    TransformArgs(
+                        targets=[*self.mappings.mlp_out],
+                        location="input",
+                    ),
+                    TransformArgs(
+                        targets=[*self.mappings.mlp_out],
+                        location="weight_input",
+                        inverse=True,
+                    ),
+                ],
+            ),
+        )
diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
index 7dc327b78..36102b975 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -10,6 +10,7 @@ class SpinQuantMapping(BaseModel):
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -31,6 +32,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",

From 5aa35865b5f2e2cdc3b3a8b8f8ca157829910757 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 16 Jul 2025 17:29:06 -0400
Subject: [PATCH 3/4] use get_compressed_attention_impl

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/modifiers/quantization/quantization/mixin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 4f1ee46a2..71c677039 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -243,7 +243,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
 
         # TODO: attnq
-        # attention_impl = enable_compressed_attention(model)
+        # attention_impl = get_compressed_attention_impl()
         # hooks |= register_calibrate_attn_hooks(self, attention_impl)
 
         for module in model.modules():

From a9b2f517f34f8be0f9f8cfb59a3a290f697e836b Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 17 Jul 2025 17:28:57 -0400
Subject: [PATCH 4/4] r3 r4 works, but not with sdpa

Signed-off-by: Kyle Sayers
---
 examples/transform/spinquant_example.py | 17 +++--
 .../modifiers/quantization/calibration.py | 4 +-
 .../quantization/quantization/mixin.py | 9 ++-
 .../modifiers/transform/spinquant/base.py | 76 ++++++++++---------
 4 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py
index 876db7138..6671af923 100644
--- a/examples/transform/spinquant_example.py
+++ b/examples/transform/spinquant_example.py
@@ -13,7 +13,7 @@
     MODEL_ID,
     torch_dtype="auto",
 )
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, attn_implementation="eager")
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
@@ -58,8 +58,10 @@ def tokenize(sample):
 # * apply spinquant transforms to model in order to make quantization easier
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = [
-    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
-    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+    SpinQuantModifier(
+        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
+    ),
+    # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
 
 # Apply algorithms.
@@ -75,9 +77,12 @@ def tokenize(sample):
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
+from llmcompressor.utils import calibration_forward_context
+
+with calibration_forward_context(model):
+    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+    output = model.generate(input_ids, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
 print("==========================================\n\n")
 
 # Save to disk compressed.
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index e460e1f27..fe824695e 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -21,6 +21,8 @@
 from llmcompressor.utils.helpers import getattr_chain
 
 if TYPE_CHECKING:
+    from compressed_tensors.modeling.attention import CompressedAttentionImpl
+
     from llmcompressor.modifiers.utils.hooks import HooksMixin
 
 
@@ -213,7 +215,7 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
 
 def register_calibrate_attn_hooks(
-    modifier: HooksMixin, attention_impl
+    modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
 ) -> Set[RemovableHandle]:
     return {
         modifier.register_hook(
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 71c677039..7c7a41033 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -242,10 +242,6 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-
-        # TODO: attnq
-        # attention_impl = get_compressed_attention_impl()
-        # hooks |= register_calibrate_attn_hooks(self, attention_impl)
-
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
@@ -264,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                     self.register_hook(module, calibrate_input_hook, "forward_pre")
                 )
 
+            # TODO: attnq
+            # if is_attention:
+            #     attention_impl = CompressedAttentionImpl.from_module(module)
+            #     hooks |= register_calibrate_attn_hooks(self, attention_impl)
+
             # kv_cache activations. Within `apply_quantization_config`, the config is
             # modified to use attention output quantization if a kv_cache_scheme exists
             if is_attention and output:
diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 2204aec13..bd78525d3 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -109,7 +109,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             config_groups["R2"] = self._create_r2_scheme(state.model)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(state.model)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -214,41 +214,47 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        return (
-            TransformScheme(
-                type=self.transform_type,
-                randomize=self.randomize,
-                requires_grad=self.learnable,
-                apply=[
-                    TransformArgs(
-                        targets=[self.mappings.attn],
-                        location="attn_q",
-                    ),
-                    TransformArgs(
-                        targets=[self.mappings.attn],
-                        location="attn_k",
-                    ),
-                ],
-            ),
+    def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            head_dim = config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
+
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_q",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_k",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
-        return (
-            TransformScheme(
-                type=self.transform_type,
-                randomize=self.randomize,
-                requires_grad=self.learnable,
-                apply=[
-                    TransformArgs(
-                        targets=[*self.mappings.mlp_out],
-                        location="input",
-                    ),
-                    TransformArgs(
-                        targets=[*self.mappings.mlp_out],
-                        location="weight_input",
-                        inverse=True,
-                    ),
-                ],
-            ),
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            apply=[
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="input",
+                ),
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
         )
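
The sketch below illustrates the calibration flow that PATCH 2 wires up: a pre-hook bound with functools.partial so that query, key, and value activations are each calibrated under their own observer base name. This is a minimal, self-contained sketch; FakeAttention, register_qkv_hook, and the observed dict are illustrative stand-ins, not part of the patches. The real code routes the hooks through HooksMixin.register_hook and a CompressedAttentionImpl, and updates quantization observers rather than recording shapes.

from functools import partial
from typing import Any

import torch
from torch.nn import Module


def calibrate_activations(module: Module, value: torch.Tensor, base_name: str) -> None:
    # Stand-in for the observer update: remember the shapes seen per base name.
    module.observed.setdefault(base_name, []).append(tuple(value.shape))


def calibrate_input_hook(module: Module, args: Any, base_name: str = "input") -> None:
    # Hooks receive their positional args as a tuple; unwrap before calibrating.
    args = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=args, base_name=base_name)


class FakeAttention(Module):
    # Toy attention module exposing query/key/value hook points.
    def __init__(self):
        super().__init__()
        self.observed = {}
        self._qkv_hooks = {"q": [], "k": [], "v": []}

    def register_qkv_hook(self, name, hook):
        self._qkv_hooks[name].append(hook)

    def forward(self, q, k, v):
        # Fire the per-tensor calibration hooks before the attention math would run.
        for name, tensor in (("q", q), ("k", k), ("v", v)):
            for hook in self._qkv_hooks[name]:
                hook(self, (tensor,))
        return q  # attention math omitted


attn = FakeAttention()
for name in ("q", "k", "v"):
    # partial() binds the base name, mirroring register_calibrate_attn_hooks.
    attn.register_qkv_hook(name, partial(calibrate_input_hook, base_name=name))

attn(torch.randn(1, 8, 64), torch.randn(1, 8, 64), torch.randn(1, 8, 64))
print(attn.observed)  # {'q': [(1, 8, 64)], 'k': [(1, 8, 64)], 'v': [(1, 8, 64)]}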
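
The R3 scheme added in PATCH 4 needs the per-head dimension of the attention layers, taken from the model config. Below is a short sketch of that fallback logic under the assumption of a Hugging Face-style config object; resolve_head_dim and the SimpleNamespace configs are illustrative only, while the patch inlines similar checks directly inside _create_r3_scheme.

from types import SimpleNamespace


def resolve_head_dim(config) -> int:
    # Prefer an explicit head_dim; otherwise derive it from hidden_size and num_attention_heads.
    if getattr(config, "head_dim", None) is not None:
        return config.head_dim
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    raise NotImplementedError("config exposes neither head_dim nor hidden_size/num_attention_heads")


print(resolve_head_dim(SimpleNamespace(head_dim=128)))                              # 128
print(resolve_head_dim(SimpleNamespace(hidden_size=4096, num_attention_heads=32)))  # 128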