diff --git a/load_transform_model.py b/load_transform_model.py
new file mode 100644
index 000000000..27d120707
--- /dev/null
+++ b/load_transform_model.py
@@ -0,0 +1,111 @@
+from pathlib import Path
+
+from safetensors import safe_open
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+MODEL_ID = "/home/dsikka/Llama-3.2-1B-Instruct-W4A16-uncompressed-hadamard-random-debug"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    device_map="auto",
+    torch_dtype="auto",
+    quantization_config=CompressedTensorsConfig(run_compressed=False),
+)
+breakpoint()
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+
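+# Quick accuracy check of the reloaded checkpoint: run GSM8K through
+# lm_eval.simple_evaluate, reusing the same run_compressed=False config.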
+import lm_eval
+
+results = lm_eval.simple_evaluate(
+    model="hf",
+    model_args={
+        "pretrained": MODEL_ID,
+        "add_bos_token": True,
+        "quantization_config": CompressedTensorsConfig(run_compressed=False),
+    },
+    tasks=["gsm8k"],
+    num_fewshot=8,
+    limit=1000,
+    device="cuda:0",
+    batch_size=100,
+)
+print(results["results"])
+"""
+For: Llama-3.2-1B-Instruct
+
+Dense:
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.379,
+'exact_match_stderr,strict-match': 0.015349091002225352,
+'exact_match,flexible-extract': 0.381,
+'exact_match_stderr,flexible-extract': 0.015364734787007436}}
+
+---------------------------- MINMAX ----------------------------:
+
+QuantModifier - NO TRANSFORMS
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.177,
+'exact_match_stderr,strict-match': 0.011743632866916145,
+'exact_match,flexible-extract': 0.179,
+'exact_match_stderr,flexible-extract': 0.0117721103708122}}
+
+QuantModifier - TRANSFORMS (random)
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.231,
+'exact_match_stderr,strict-match': 0.012997843819031815,
+'exact_match,flexible-extract': 0.236,
+'exact_match_stderr,flexible-extract': 0.01301973553930782}}
+
+GPTQ
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.243,
+'exact_match_stderr,strict-match': 0.013569640199177434,
+'exact_match,flexible-extract': 0.244,
+'exact_match_stderr,flexible-extract': 0.013588548437881431}}
+
+
+---------------------------- MSE ----------------------------:
+QuantModifier - No Transforms
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.195,
+'exact_match_stderr,strict-match': 0.012535235623319334,
+'exact_match,flexible-extract': 0.195,
+'exact_match_stderr,flexible-extract': 0.012535235623319334}}
+
+QuantModifier - With Transforms (random)
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.243,
+'exact_match_stderr,strict-match': 0.013569640199177457,
+'exact_match,flexible-extract': 0.244,
+'exact_match_stderr,flexible-extract': 0.013588548437881412}}
+
+QuantModifier - With Transforms (not random, not normalized)
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.261,
+'exact_match_stderr,strict-match': 0.013895037677965126,
+'exact_match,flexible-extract': 0.262,
+'exact_match_stderr,flexible-extract': 0.013912208651021352}}
+
+QuantModifier - With Transforms (not random, normalized)
+{'gsm8k': {'alias': 'gsm8k',
+'exact_match,strict-match': 0.27,
+'exact_match_stderr,strict-match': 0.014046255632633915,
+'exact_match,flexible-extract': 0.27,
+'exact_match_stderr,flexible-extract': 0.014046255632633915}}
+
+GPTQ:
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.285,
+'exact_match_stderr,strict-match': 0.014282120955200484,
+'exact_match,flexible-extract': 0.286,
+'exact_match_stderr,flexible-extract': 0.01429714686251791}}
+
+---------------------------- 8-bit ----------------------------:
+QuantModifier - With Transforms (not random, normalized)
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.371,
+'exact_match_stderr,strict-match': 0.015283736211823187,
+'exact_match,flexible-extract': 0.372,
+'exact_match_stderr,flexible-extract': 0.015292149942040577}}
+
+GPTQ
+{'gsm8k': {'alias': 'gsm8k', 'exact_match,strict-match': 0.364,
+'exact_match_stderr,strict-match': 0.01522286884052202,
+'exact_match,flexible-extract': 0.365,
+'exact_match_stderr,flexible-extract': 0.015231776226264903}}
+"""
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index bcb4b7433..23a987fd4 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -2,8 +2,8 @@
 
 import torch
 from compressed_tensors.quantization import QuantizationStatus, is_attention_module
-from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
+from compressed_tensors.transforms.apply import apply_transforms_to_activations_or_parameter
 from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
 from loguru import logger
 from torch.nn import Module
@@ -120,8 +120,21 @@ def update_weight_zp_scale(module: Module):
 
     if module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
+
+        transform_data = getattr(module, "transform_data", None)
+        if transform_data is not None:
+            untransformed_weight = module.weight.data.clone()
+            apply_transforms_to_activations_or_parameter(
+                module=module,
+                module_activation_or_parameter=module.weight,
+                transform_data=transform_data,
+            )
+
         call_observer(module=module, base_name="weight")
+        if transform_data is not None:
+            module.weight.data.copy_(untransformed_weight)
+
 
 
 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
@@ -138,11 +151,22 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return
 
+    transform_data = getattr(module, "transform_data", None)
+    if transform_data is not None:
+        value = apply_transforms_to_activations_or_parameter(
+            module=module,
+            module_activation_or_parameter=value,
+            transform_data=transform_data,
+            update_in_place=False,
+        )
+
     call_observer(
         module=module,
         base_name=base_name,
         value=value,
     )
+    breakpoint()
+    # validate value is correct
 
 
 def calibrate_input_hook(module: Module, args: Any):
@@ -166,12 +190,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
         value=output,
         base_name="output",
     )
-    output = forward_quantize(
-        module=module,
-        value=output,
-        base_name="output",
-        args=module.quantization_scheme.output_activations,
-    )
     return output
 
 
@@ -197,7 +215,6 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
     update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
     update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
 
-
 def set_unset_kv_cache(module: Module):
     """
     Set or unset singleton QuantizedKVParameterCache for each
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 3a8946aef..b514d0c0b 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Set
 
 from compressed_tensors.quantization import (
     QuantizationArgs,
@@ -10,6 +10,12 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.quantization.lifecycle import (
+    post_forward_quantize,
+    pre_forward_quantize,
+    register_quantization_hooks,
+)
+from compressed_tensors.transforms.transform_config import TransformationConfig
 from loguru import logger
 from pydantic import Field, field_validator
 from torch.nn import Module
@@ -74,6 +80,7 @@ class QuantizationModifier(Modifier):
     """
 
     config_groups: Optional[Dict[str, QuantizationScheme]] = None
+    transforms_config: Optional[TransformationConfig] = None
     ignore: List[str] = Field(default_factory=list)
     targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
     scheme: Optional[Union[str, Dict[str, Any]]] = None
@@ -83,6 +90,7 @@ class QuantizationModifier(Modifier):
 
     calibration_dataloader_: Any = None
     calibration_function_: Any = None
+    _handles: Set = set()
 
     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
@@ -210,7 +218,12 @@ def _check_calibration_data(self, config: QuantizationConfig):
     def _apply_modifier_to_model(self, model: Module):
         modifier_as_config = self.create_init_config()
         # Add step to attach kv_cache to the model, if present within the config
-        apply_quantization_config(model, modifier_as_config)
+        apply_quantization_config(
+            model,
+            modifier_as_config,
+            transforms_config=self.transforms_config,
+            delay_forward_quantize=True,
+        )
         model.apply(set_unset_kv_cache)
         return modifier_as_config
@@ -258,6 +271,9 @@ def _calibrate_if_possible(self, module: Module):
             )
 
         elif not self.calibration_dataloader_:
+            # TODO: should just use HooksMixin
+            # hooks should have been delayed
+            module.apply(lambda model: register_quantization_hooks(model))
             return
 
         module.apply(lambda model: initialize_observer(model, base_name="input"))
@@ -265,7 +281,7 @@ def _calibrate_if_possible(self, module: Module):
         module.apply(self.register_calibration_hooks)
         self._calibrate(module)
         module.apply(set_unset_kv_cache)
-        self.remove_hooks()
+        self.remove_hooks(self._handles)
 
     def register_calibration_hooks(self, module: Module):
         """
@@ -285,23 +301,39 @@ def register_calibration_hooks(self, module: Module):
 
         # Calibrate inputs if an input_quant is provided and not running dynamic quant
         if calibrate_inputs:
-            self.register_hook(module, calibrate_input_hook, "forward_pre")
+            self._handles.add(
+                self.register_hook(module, calibrate_input_hook, "forward_pre")
+            )
+
+        if not is_attention_module_:
+            self.register_hook(module, pre_forward_quantize, "forward_pre")
 
         if output_quant:
             # hooks for attn modules if running kv_cache quant
             if is_attention_module_:
-                self.register_hook(
-                    module,
-                    calibrate_kv_cache_input_hook,
-                    "forward_pre",
-                    with_kwargs=True,
+                self._handles.add(
+                    self.register_hook(
+                        module,
+                        calibrate_kv_cache_input_hook,
+                        "forward_pre",
+                        with_kwargs=True,
+                    )
                 )
-                self.register_hook(module, calibrate_kv_cache_output_hook, "forward")
+                self._handles.add(
+                    self.register_hook(
+                        module, calibrate_kv_cache_output_hook, "forward"
+                    )
+                )
 
             # hooks for output quant if not running dynamic quant
             elif not output_quant.dynamic:
-                self.register_hook(module, calibrate_output_hook, "forward")
+                self._handles.add(
+                    self.register_hook(module, calibrate_output_hook, "forward")
+                )
+
+        if not is_attention_module_:
+            self.register_hook(module, post_forward_quantize, "forward")
 
     def _calibrate(self, module: Module):
diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index 75fad8311..6c5919652 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -1125,7 +1125,6 @@ def calibration_forward_context(model: PreTrainedModel):
     with (
         torch.no_grad(),
         DisableKVCache(model),
-        DisableQuantization(model),
         eval_context(model),
     ):
         yield
diff --git a/weight_transform.py b/weight_transform.py
new file mode 100644
index 000000000..74b61078f
--- /dev/null
+++ b/weight_transform.py
@@ -0,0 +1,145 @@
+import torch
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
+from compressed_tensors.transforms import Hadamard, RandomHadamard, Transforms
+from compressed_tensors.transforms.transform_args import (
+    ModuleTarget,
+    TransformationArgs,
+)
+from compressed_tensors.transforms.transform_config import TransformationConfig
+from compressed_tensors.transforms.transform_scheme import TransformationScheme
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# U(W)V.T
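+# (Sketch of the intent, as read from the config below: each targeted weight W
+# is wrapped in Hadamard transforms, roughly quantize(U @ W @ V.T), so the
+# observers calibrate on the rotated weight rather than on W itself.)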
+
+ignore = ["re:.*.mlp.down_proj$", "lm_head"]
+module_targets = [ModuleTarget.WEIGHT.value]
+
+# Start with a processed
+targets = ["Linear"]  # 2048 * 2048
+v_linear_args = TransformationArgs(
+    targets=targets,
+    module_targets=module_targets,
+    ignore=ignore,
+    call_args={"transpose": True, "first": False},
+)
+
+targets = ["re:.*.mlp.down_proj$"]  # 8192 * 8192
+v_down_proj = TransformationArgs(
+    targets=targets,
+    module_targets=module_targets,
+    call_args={"transpose": True, "first": False},
+)
+
+targets = [
+    "re:.*.attn.q_proj$",
+    "re:.*.attn.o_proj$",
+    "re:.*.mlp.down_proj$",
+]  # 2048 * 2048
+u_q_o_down_proj = TransformationArgs(
+    targets=targets,
+    module_targets=module_targets,
+)
+
+targets = ["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"]  # 8192 * 8192
+u_gate_up_proj = TransformationArgs(
+    targets=targets,
+    module_targets=module_targets,
+)
+
+targets = ["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"]  # 512 * 512
+u_k_v_proj = TransformationArgs(
+    targets=targets,
+    module_targets=module_targets,
+)
+
+
+# This will apply the random_had to the first set of args.
+# It will then apply the second set of args.
+# Any overlap will be applied in order.
+v_scheme = TransformationScheme(
+    transform_type="hadamard",
+    groups=[v_linear_args],
+    transform_creation_args={"size": 2048},
+)
+
+v_scheme_down_proj = TransformationScheme(
+    transform_type="hadamard",
+    groups=[v_down_proj],
+    transform_creation_args={"size": 8192},
+)
+
+# We could combine multiple args into the same scheme, but that would make it
+# harder to consolidate the order of transforms.
+u_scheme_q_o_down_proj = TransformationScheme(
+    transform_type="hadamard",
+    groups=[u_q_o_down_proj],
+    transform_creation_args={"size": 2048},
+)
+
+u_scheme_gate_up_proj = TransformationScheme(
+    transform_type="hadamard",
+    groups=[u_gate_up_proj],
+    transform_creation_args={"size": 8192},
+)
+
+u_scheme_k_v_proj = TransformationScheme(
+    transform_type="hadamard",
+    groups=[u_k_v_proj],
+    transform_creation_args={"size": 512},
+)
+
+# QuIP recipe with weight-only quantization
+config = TransformationConfig(
+    transform_groups={
+        "u_transform_q_o_down_proj": u_scheme_q_o_down_proj,
+        "u_transform_k_v_proj": u_scheme_k_v_proj,
+        "u_transform_gate_up_proj": u_scheme_gate_up_proj,
+        "v_transform_linear": v_scheme,
+        "v_transform_down_proj": v_scheme_down_proj,
+    }
+)
+
+recipe = QuantizationModifier(
+    targets="Linear",
+    ignore=["lm_head"],
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                num_bits=4,
+                symmetric=True,
+                strategy=QuantizationStrategy.GROUP,
+                group_size=128,
+                observer="mse",
+            ),
+        )
+    },
+    transforms_config=config,
+)
+
+MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+oneshot(model=model, recipe=recipe)
+
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk, uncompressed.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-uncompressed-hadamard-random-debug"
+
+model.save_pretrained(SAVE_DIR, save_compressed=False)
+tokenizer.save_pretrained(SAVE_DIR)
\ No newline at end of file