From ba617db6a28f02481a8c6604878243af0393a85f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Jun 2025 00:13:50 -0400 Subject: [PATCH 1/4] wip Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 84 +++++++++++++++++++ .../modifiers/transform/__init__.py | 3 + .../modifiers/transform/template/quip.py | 41 +++++++++ .../modifiers/transform/template/spinquant.py | 65 ++++++++++++++ .../modifiers/transform/transform.py | 28 +++++++ 5 files changed, 221 insertions(+) create mode 100644 examples/transform/llama3_example.py create mode 100644 src/llmcompressor/modifiers/transform/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/template/quip.py create mode 100644 src/llmcompressor/modifiers/transform/template/spinquant.py create mode 100644 src/llmcompressor/modifiers/transform/transform.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py new file mode 100644 index 000000000..0c976874d --- /dev/null +++ b/examples/transform/llama3_example.py @@ -0,0 +1,84 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.transformers import oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="auto", + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + TransformModifier(), + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +] + +# Apply algorithms. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. 
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py new file mode 100644 index 000000000..85e8972b4 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .transform import TransformModifier \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py new file mode 100644 index 000000000..070fec03a --- /dev/null +++ b/src/llmcompressor/modifiers/transform/template/quip.py @@ -0,0 +1,41 @@ +from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig + + +QUIP = TransformConfig( + config_groups={ + "v": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="input", # non-mergable + ignore="lm_head", + ), + TransformArgs( + targets=["Linear"], + location="weight_input", + inverse=True, + ignore="lm_head", + ), + ], + randomize=True, + ), + "u": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore="lm_head", + ), + TransformArgs( + targets=["Linear"], + location="output", # non-mergable + inverse=True, + ignore="lm_head" + ), + ], + randomize=True, + ), + } +) \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py new file mode 100644 index 000000000..b9d7c5844 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/template/spinquant.py @@ -0,0 +1,65 @@ +from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig + + +LLAMA_SPINQUANT = TransformConfig( + transform_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["embed_tokens", "o_proj", "down_proj"], + location="weight_output", + ), + TransformArgs( + targets=[ + "q_proj", + "k_proj", + "v_proj", + "up_proj", + "gate_proj", + "lm_head", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["v_proj"], + location="weight_output", + ), + TransformArgs( + targets=["o_proj"], location="weight_input", inverse=True + ), + ], + ), + "R3": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["self_attn"], + location="k_cache", + ), + TransformArgs( + targets=["self_attn"], + location="q_attn", + ), + ], + ), + "R4": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["down_proj"], + location="input", + ), + TransformArgs( + targets=["down_proj"], location="weight_input", inverse=True + ), + ], + ), + } +) \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py new file mode 100644 index 000000000..1700e12f1 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -0,0 +1,28 @@ +from typing import Dict, Optional + +from llmcompressor.core import State +from llmcompressor.modifiers import Modifier + +from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory + +from .template.quip import QUIP + +class TransformModifier(Modifier): + preset_config: Optional[str] = None + config_groups: Optional[Dict[str, TransformScheme]] = 
None + + # model validator to validate both preset and config gropus are not provided + + def on_initialize(self, state: State, **kwargs): + if self.preset_config is not None: + # import config template and customize to model + pass + + + #config = TransformConfig(config_groups=self.config_groups) + config = QUIP + + # TODO: use CT-provided apply_transform_config + for name, scheme in config.config_groups.items(): + factory = TransformFactory.from_scheme(scheme, name=name) + factory.apply_to_model(state.model) \ No newline at end of file From 2f5b1c8a20ddffd9f83cf2984d36007bf7cdefe5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Jun 2025 23:15:46 -0400 Subject: [PATCH 2/4] use random-hadamard, add correctness tests Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 2 +- src/llmcompressor/modifiers/transform/__init__.py | 2 +- .../modifiers/transform/template/quip.py | 11 +++++------ .../modifiers/transform/template/spinquant.py | 5 ++--- src/llmcompressor/modifiers/transform/transform.py | 14 ++++++-------- 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 0c976874d..41bb4921c 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -58,7 +58,7 @@ def tokenize(sample): # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ TransformModifier(), - GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 85e8972b4..6c65678af 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,3 @@ # flake8: noqa -from .transform import TransformModifier \ No newline at end of file +from .transform import TransformModifier diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py index 070fec03a..e39c32e6d 100644 --- a/src/llmcompressor/modifiers/transform/template/quip.py +++ b/src/llmcompressor/modifiers/transform/template/quip.py @@ -1,10 +1,9 @@ -from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig - +from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme QUIP = TransformConfig( config_groups={ "v": TransformScheme( - type="hadamard", + type="random-hadamard", apply=[ TransformArgs( targets=["Linear"], @@ -21,7 +20,7 @@ randomize=True, ), "u": TransformScheme( - type="hadamard", + type="random-hadamard", apply=[ TransformArgs( targets=["Linear"], @@ -32,10 +31,10 @@ targets=["Linear"], location="output", # non-mergable inverse=True, - ignore="lm_head" + ignore="lm_head", ), ], randomize=True, ), } -) \ No newline at end of file +) diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py index b9d7c5844..d628cbfd9 100644 --- a/src/llmcompressor/modifiers/transform/template/spinquant.py +++ b/src/llmcompressor/modifiers/transform/template/spinquant.py @@ -1,5 +1,4 @@ -from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig - +from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme LLAMA_SPINQUANT = TransformConfig( transform_groups={ @@ -62,4 +61,4 @@ ], 
), } -) \ No newline at end of file +) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 1700e12f1..6cd1417b5 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -1,12 +1,13 @@ from typing import Dict, Optional +from compressed_tensors.transform import TransformScheme, apply_transform_config + from llmcompressor.core import State from llmcompressor.modifiers import Modifier -from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory - from .template.quip import QUIP + class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None @@ -18,11 +19,8 @@ def on_initialize(self, state: State, **kwargs): # import config template and customize to model pass - - #config = TransformConfig(config_groups=self.config_groups) + # config = TransformConfig(config_groups=self.config_groups) config = QUIP - # TODO: use CT-provided apply_transform_config - for name, scheme in config.config_groups.items(): - factory = TransformFactory.from_scheme(scheme, name=name) - factory.apply_to_model(state.model) \ No newline at end of file + apply_transform_config(state.model, config) + breakpoint() From 3aa35e727143ee35cc1226fe86d863de8eff85df Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Jun 2025 23:22:24 -0400 Subject: [PATCH 3/4] add correctness test, note that precision makes a large difference Signed-off-by: Kyle Sayers --- .../modifiers/transform/test_correctness.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/llmcompressor/modifiers/transform/test_correctness.py diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..8fca9639b --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,29 @@ +import pytest +import torch +from compressed_tensors.transform import apply_transform_config +from transformers import AutoModelForCausalLM + +from llmcompressor.modifiers.transform.template.quip import QUIP + + +@pytest.mark.parametrize( + "dtype,exp_max,exp_mse", [ + (torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9) + ] +) +def test_apply_correctness(dtype, exp_max, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype + ) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + apply_transform_config(model, QUIP) + with torch.no_grad(): + output = model(**input) + + assert torch.max(true_output.logits - output.logits) <= exp_max + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse From b6c088e787454b419962544aa2ce9f852b73692a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 23 Jun 2025 20:18:52 +0000 Subject: [PATCH 4/4] add on lifecycle methods Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 4 +-- .../modifiers/transform/transform.py | 33 ++++++++++++++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 41bb4921c..b868d4b2a 100644 --- a/examples/transform/llama3_example.py +++ 
b/examples/transform/llama3_example.py @@ -3,14 +3,13 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.modifiers.transform import TransformModifier -from llmcompressor.transformers import oneshot +from llmcompressor import oneshot # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", torch_dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -66,6 +65,7 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + pipeline="sequential", max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 6cd1417b5..6b8e89927 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -2,7 +2,7 @@ from compressed_tensors.transform import TransformScheme, apply_transform_config -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from .template.quip import QUIP @@ -12,9 +12,9 @@ class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None - # model validator to validate both preset and config gropus are not provided + # model validator to validate both preset and config groups are not provided - def on_initialize(self, state: State, **kwargs): + def on_initialize(self, state: State, **kwargs) -> bool: if self.preset_config is not None: # import config template and customize to model pass @@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs): config = QUIP apply_transform_config(state.model, config) - breakpoint() + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True
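
---

A note on why the QUIP template (patch 1) pairs each activation-side transform with an `inverse=True` transform on the adjacent weight side: multiplying a layer's input by an orthogonal rotation and folding the inverse into the weight's input dimension leaves the layer's output unchanged. That is what allows the weight-side halves to be merged offline while the activation-side halves (marked "non-mergable" in the template) remain runtime ops. A minimal self-contained sketch of the identity, using a QR-derived orthogonal matrix in place of a true Hadamard — this is plain linear algebra, not llmcompressor or compressed-tensors API:

import torch

torch.manual_seed(0)
d = 8
W = torch.randn(d, d)  # original linear weight
x = torch.randn(d)     # activation

# random orthogonal rotation standing in for a Hadamard transform (H^-1 == H^T)
H = torch.linalg.qr(torch.randn(d, d)).Q

y_ref = W @ x

# location="input": rotate the activation on the way in
x_t = H @ x
# location="weight_input", inverse=True: fold H^-1 into the weight's input side
W_t = W @ H.T

# W_t @ x_t == W @ H^T @ H @ x == W @ x
assert torch.allclose(W_t @ x_t, y_ref, atol=1e-5)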
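
The `# model validator to validate both preset and config groups are not provided` comment in `TransformModifier` is still unimplemented at the end of this series. A sketch of what that check might look like, assuming the pydantic v2 API that llmcompressor modifiers are built on; the method name `validate_transform_config` is hypothetical, not part of this PR:

from typing import Dict, Optional

from compressed_tensors.transform import TransformScheme
from pydantic import model_validator

from llmcompressor.modifiers import Modifier


class TransformModifier(Modifier):
    preset_config: Optional[str] = None
    config_groups: Optional[Dict[str, TransformScheme]] = None

    @model_validator(mode="after")
    def validate_transform_config(self) -> "TransformModifier":
        # reject recipes that set both a preset template and custom config
        # groups, since one would silently override the other
        if self.preset_config is not None and self.config_groups is not None:
            raise ValueError(
                "TransformModifier accepts either `preset_config` or "
                "`config_groups`, not both"
            )
        return self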
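
Once the preset lookup stubbed out in `on_initialize` is filled in, a recipe could select a template by name rather than relying on the hard-coded `config = QUIP`. Hypothetical usage, assuming the QUIP template from `transform/template/quip.py` is registered under the preset name "QUIP":

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.transform import TransformModifier

# hypothetical: "QUIP" resolves to the QUIP TransformConfig template
recipe = [
    TransformModifier(preset_config="QUIP"),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]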