diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py new file mode 100644 index 000000000..e5f7faea0 --- /dev/null +++ b/examples/transform/quip_example.py @@ -0,0 +1,87 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import QuIPModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + QuIPModifier(transform_type="random-hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +# Apply algorithms. +oneshot( + model=model, + recipe=recipe, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + pipeline="basic", +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 9956d0340..eaa714183 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa +from .quip import QuIPModifier from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/quip/__init__.py b/src/llmcompressor/modifiers/transform/quip/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..8c86a1471 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -0,0 +1,131 @@ +from typing import List, Literal, Optional, Union + +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from pydantic import Field, ValidationInfo, field_validator + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modifiers import Modifier + +__all__ = ["QuIPModifier"] + + +class QuIPModifier(Modifier): + """ + Implements the transforms according to + [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501 + [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achived through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + QuIP and QuIP# apply transforms to every linear layer, two of which are fused into + the model weights and two of which remain as online rotations computed at runtime. + + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of power of two. + `"random-matrix"` has more performance cost, but supports a much larger set of + sizes. + `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: If true, create distinct transforms for each application + :param learnable: If true, attach gradients to transform weights for training + :param ignore: Modules to ignore when attaching transforms + :param transform_config: Optional transform config for overriding provided arguments + """ + + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard", exclude=True + ) + randomize: bool = Field(default=False, exclude=True) + learnable: bool = Field(default=False, exclude=True) + ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True) + + # optional override for more fine-grained control + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + raise NotImplementedError(f"{info.field_name} is not supported right now") + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.transform_config is not None: + return True + + self.transform_config = self._create_config() + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + apply_transform_config(state.model, self.transform_config) + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + def _create_config(self) -> TransformConfig: + return TransformConfig( + config_groups={ + "v": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=["Linear"], + location="input", # non-mergable + ignore=self.ignore, + ), + TransformArgs( + targets=["Linear"], + location="weight_input", + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + ), + "u": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore=self.ignore, + ), + TransformArgs( + targets=["Linear"], + location="output", # non-mergable + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + ), + } + )