From ba617db6a28f02481a8c6604878243af0393a85f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Jun 2025 00:13:50 -0400 Subject: [PATCH 01/35] wip Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 84 +++++++++++++++++++ .../modifiers/transform/__init__.py | 3 + .../modifiers/transform/template/quip.py | 41 +++++++++ .../modifiers/transform/template/spinquant.py | 65 ++++++++++++++ .../modifiers/transform/transform.py | 28 +++++++ 5 files changed, 221 insertions(+) create mode 100644 examples/transform/llama3_example.py create mode 100644 src/llmcompressor/modifiers/transform/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/template/quip.py create mode 100644 src/llmcompressor/modifiers/transform/template/spinquant.py create mode 100644 src/llmcompressor/modifiers/transform/transform.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py new file mode 100644 index 000000000..0c976874d --- /dev/null +++ b/examples/transform/llama3_example.py @@ -0,0 +1,84 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.transformers import oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="auto", + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + TransformModifier(), + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +] + +# Apply algorithms. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. 
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py new file mode 100644 index 000000000..85e8972b4 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .transform import TransformModifier \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py new file mode 100644 index 000000000..070fec03a --- /dev/null +++ b/src/llmcompressor/modifiers/transform/template/quip.py @@ -0,0 +1,41 @@ +from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig + + +QUIP = TransformConfig( + config_groups={ + "v": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="input", # non-mergable + ignore="lm_head", + ), + TransformArgs( + targets=["Linear"], + location="weight_input", + inverse=True, + ignore="lm_head", + ), + ], + randomize=True, + ), + "u": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore="lm_head", + ), + TransformArgs( + targets=["Linear"], + location="output", # non-mergable + inverse=True, + ignore="lm_head" + ), + ], + randomize=True, + ), + } +) \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py new file mode 100644 index 000000000..b9d7c5844 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/template/spinquant.py @@ -0,0 +1,65 @@ +from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig + + +LLAMA_SPINQUANT = TransformConfig( + transform_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["embed_tokens", "o_proj", "down_proj"], + location="weight_output", + ), + TransformArgs( + targets=[ + "q_proj", + "k_proj", + "v_proj", + "up_proj", + "gate_proj", + "lm_head", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["v_proj"], + location="weight_output", + ), + TransformArgs( + targets=["o_proj"], location="weight_input", inverse=True + ), + ], + ), + "R3": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["self_attn"], + location="k_cache", + ), + TransformArgs( + targets=["self_attn"], + location="q_attn", + ), + ], + ), + "R4": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["down_proj"], + location="input", + ), + TransformArgs( + targets=["down_proj"], location="weight_input", inverse=True + ), + ], + ), + } +) \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py new file mode 100644 index 000000000..1700e12f1 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -0,0 +1,28 @@ +from typing import Dict, Optional + +from llmcompressor.core import State +from llmcompressor.modifiers import Modifier + +from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory + +from .template.quip import QUIP + +class TransformModifier(Modifier): + preset_config: Optional[str] = None + config_groups: Optional[Dict[str, TransformScheme]] = 
None + + # model validator to validate both preset and config gropus are not provided + + def on_initialize(self, state: State, **kwargs): + if self.preset_config is not None: + # import config template and customize to model + pass + + + #config = TransformConfig(config_groups=self.config_groups) + config = QUIP + + # TODO: use CT-provided apply_transform_config + for name, scheme in config.config_groups.items(): + factory = TransformFactory.from_scheme(scheme, name=name) + factory.apply_to_model(state.model) \ No newline at end of file From 2f5b1c8a20ddffd9f83cf2984d36007bf7cdefe5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Jun 2025 23:15:46 -0400 Subject: [PATCH 02/35] use random-hadamard, add correctness tests Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 2 +- src/llmcompressor/modifiers/transform/__init__.py | 2 +- .../modifiers/transform/template/quip.py | 11 +++++------ .../modifiers/transform/template/spinquant.py | 5 ++--- src/llmcompressor/modifiers/transform/transform.py | 14 ++++++-------- 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 0c976874d..41bb4921c 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -58,7 +58,7 @@ def tokenize(sample): # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ TransformModifier(), - GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 85e8972b4..6c65678af 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,3 @@ # flake8: noqa -from .transform import TransformModifier \ No newline at end of file +from .transform import TransformModifier diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py index 070fec03a..e39c32e6d 100644 --- a/src/llmcompressor/modifiers/transform/template/quip.py +++ b/src/llmcompressor/modifiers/transform/template/quip.py @@ -1,10 +1,9 @@ -from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig - +from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme QUIP = TransformConfig( config_groups={ "v": TransformScheme( - type="hadamard", + type="random-hadamard", apply=[ TransformArgs( targets=["Linear"], @@ -21,7 +20,7 @@ randomize=True, ), "u": TransformScheme( - type="hadamard", + type="random-hadamard", apply=[ TransformArgs( targets=["Linear"], @@ -32,10 +31,10 @@ targets=["Linear"], location="output", # non-mergable inverse=True, - ignore="lm_head" + ignore="lm_head", ), ], randomize=True, ), } -) \ No newline at end of file +) diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py index b9d7c5844..d628cbfd9 100644 --- a/src/llmcompressor/modifiers/transform/template/spinquant.py +++ b/src/llmcompressor/modifiers/transform/template/spinquant.py @@ -1,5 +1,4 @@ -from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig - +from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme LLAMA_SPINQUANT = TransformConfig( transform_groups={ @@ -62,4 +61,4 @@ 
], ), } -) \ No newline at end of file +) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 1700e12f1..6cd1417b5 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -1,12 +1,13 @@ from typing import Dict, Optional +from compressed_tensors.transform import TransformScheme, apply_transform_config + from llmcompressor.core import State from llmcompressor.modifiers import Modifier -from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory - from .template.quip import QUIP + class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None @@ -18,11 +19,8 @@ def on_initialize(self, state: State, **kwargs): # import config template and customize to model pass - - #config = TransformConfig(config_groups=self.config_groups) + # config = TransformConfig(config_groups=self.config_groups) config = QUIP - # TODO: use CT-provided apply_transform_config - for name, scheme in config.config_groups.items(): - factory = TransformFactory.from_scheme(scheme, name=name) - factory.apply_to_model(state.model) \ No newline at end of file + apply_transform_config(state.model, config) + breakpoint() From 3aa35e727143ee35cc1226fe86d863de8eff85df Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Jun 2025 23:22:24 -0400 Subject: [PATCH 03/35] add correctness test, note that precision makes a large difference Signed-off-by: Kyle Sayers --- .../modifiers/transform/test_correctness.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/llmcompressor/modifiers/transform/test_correctness.py diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..8fca9639b --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,29 @@ +import pytest +import torch +from compressed_tensors.transform import apply_transform_config +from transformers import AutoModelForCausalLM + +from llmcompressor.modifiers.transform.template.quip import QUIP + + +@pytest.mark.parametrize( + "dtype,exp_max,exp_mse", [ + (torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9) + ] +) +def test_apply_correctness(dtype, exp_max, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype + ) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + apply_transform_config(model, QUIP) + with torch.no_grad(): + output = model(**input) + + assert torch.max(true_output.logits - output.logits) <= exp_max + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse From b6c088e787454b419962544aa2ce9f852b73692a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 23 Jun 2025 20:18:52 +0000 Subject: [PATCH 04/35] add on lifecycle methods Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 4 +-- .../modifiers/transform/transform.py | 33 ++++++++++++++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 41bb4921c..b868d4b2a 100644 --- 
a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -3,14 +3,13 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.modifiers.transform import TransformModifier -from llmcompressor.transformers import oneshot +from llmcompressor import oneshot # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", torch_dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -66,6 +65,7 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + pipeline="sequential", max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 6cd1417b5..6b8e89927 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -2,7 +2,7 @@ from compressed_tensors.transform import TransformScheme, apply_transform_config -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from .template.quip import QUIP @@ -12,9 +12,9 @@ class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None - # model validator to validate both preset and config gropus are not provided + # model validator to validate both preset and config groups are not provided - def on_initialize(self, state: State, **kwargs): + def on_initialize(self, state: State, **kwargs) -> bool: if self.preset_config is not None: # import config template and customize to model pass @@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs): config = QUIP apply_transform_config(state.model, config) - breakpoint() + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True From 320712434f3bdbae7330f3d6dc2a4f0f0224a497 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 15:07:42 +0000 Subject: [PATCH 05/35] TransformModifier with SpinQuant R1&R2 Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 26 +++++------ .../modifiers/transform/__init__.py | 1 + .../modifiers/transform/presets/__init__.py | 8 ++++ .../transform/{template => presets}/quip.py | 0 .../{template => presets}/spinquant.py | 43 +++++++++++++++++++ .../modifiers/transform/transform.py | 31 +++++++------ 6 files changed, 84 insertions(+), 25 deletions(-) create mode 100644 src/llmcompressor/modifiers/transform/presets/__init__.py rename src/llmcompressor/modifiers/transform/{template => presets}/quip.py (100%) rename src/llmcompressor/modifiers/transform/{template => presets}/spinquant.py (61%) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index b868d4b2a..90051c9a8 100644 --- 
a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -1,9 +1,10 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.modifiers.transform import TransformModifier from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.utils import dispatch_for_generation # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -56,8 +57,8 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - TransformModifier(), - GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. @@ -70,15 +71,16 @@ def tokenize(sample): num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") +# # Confirm generations of the quantized model look sane. +# print("\n\n") +# print("========== SAMPLE GENERATION ==============") +# dispatch_for_generation(model) +# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +# output = model.generate(input_ids, max_new_tokens=100) +# print(tokenizer.decode(output[0])) +# print("==========================================\n\n") # Save to disk compressed. 
-SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 6c65678af..036d35b60 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .transform import TransformModifier +from .transform.presets import TRANSFORM_PRESETS diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py new file mode 100644 index 000000000..a36bbc4d1 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/presets/__init__.py @@ -0,0 +1,8 @@ +from .quip import QUIP +from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 + +TRANSFORM_PRESETS = { + "QUIP": QUIP, + "LLAMA_SPINQUANT": LLAMA_SPINQUANT, + "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, +} diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/presets/quip.py similarity index 100% rename from src/llmcompressor/modifiers/transform/template/quip.py rename to src/llmcompressor/modifiers/transform/presets/quip.py diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py similarity index 61% rename from src/llmcompressor/modifiers/transform/template/spinquant.py rename to src/llmcompressor/modifiers/transform/presets/spinquant.py index d628cbfd9..194818b38 100644 --- a/src/llmcompressor/modifiers/transform/template/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -1,5 +1,8 @@ from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme +# Ref: https://arxiv.org/pdf/2405.16406 Fig 1 + +# All rotations LLAMA_SPINQUANT = TransformConfig( transform_groups={ "R1": TransformScheme( @@ -62,3 +65,43 @@ ), } ) + + +# Mergeable rotations R1 and R2 only +LLAMA_SPINQUANT_R1R2 = TransformConfig( + config_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["embed_tokens", "o_proj", "down_proj"], + location="weight_output", + ), + TransformArgs( + targets=[ + "q_proj", + "k_proj", + "v_proj", + "up_proj", + "gate_proj", + "lm_head", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["v_proj"], + location="weight_output", + ), + TransformArgs( + targets=["o_proj"], location="weight_input", inverse=True + ), + ], + ), + } +) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 6b8e89927..d7ac10aaa 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -1,28 +1,33 @@ -from typing import Dict, Optional +from typing import Optional -from compressed_tensors.transform import TransformScheme, apply_transform_config +from compressed_tensors.transform import TransformConfig, apply_transform_config +from pydantic import ValidationError, model_validator from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier - -from .template.quip import QUIP +from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS class TransformModifier(Modifier): 
preset_config: Optional[str] = None - config_groups: Optional[Dict[str, TransformScheme]] = None + config: Optional[TransformConfig] = None # model validator to validate both preset and config groups are not provided + @model_validator(mode="after") + def validate_model_after(model: "TransformModifier") -> "TransformModifier": + if model.preset_config is None and model.config is None: + raise ValidationError("Either a config or a preset_config must be provided") + + if model.preset_config is not None: + if model.preset_config not in TRANSFORM_PRESETS: + raise ValidationError( + f"Invalid preset_config '{model.preset_config}' " + f"must be in {TRANSFORM_PRESETS.keys()}" + ) + model.config = TRANSFORM_PRESETS[model.preset_config] def on_initialize(self, state: State, **kwargs) -> bool: - if self.preset_config is not None: - # import config template and customize to model - pass - - # config = TransformConfig(config_groups=self.config_groups) - config = QUIP - - apply_transform_config(state.model, config) + apply_transform_config(state.model, self.config) return True From a88ca3c0ef4866c4239a7f34ca62ed90f9554586 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 18:59:36 +0000 Subject: [PATCH 06/35] spinquant and quip_online, running but outputting gibberish Signed-off-by: Brian Dellabetta --- .../modifiers/transform/__init__.py | 2 +- .../modifiers/transform/presets/__init__.py | 3 +- .../modifiers/transform/presets/quip.py | 58 ++++++++++++++ .../modifiers/transform/presets/spinquant.py | 78 ++++++------------- .../modifiers/transform/transform.py | 2 + .../modifiers/transform/test_correctness.py | 13 +++- 6 files changed, 95 insertions(+), 61 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 036d35b60..c43958136 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .presets import TRANSFORM_PRESETS from .transform import TransformModifier -from .transform.presets import TRANSFORM_PRESETS diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py index a36bbc4d1..0d4a06b90 100644 --- a/src/llmcompressor/modifiers/transform/presets/__init__.py +++ b/src/llmcompressor/modifiers/transform/presets/__init__.py @@ -1,8 +1,9 @@ -from .quip import QUIP +from .quip import QUIP, QUIP_ONLINE from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 TRANSFORM_PRESETS = { "QUIP": QUIP, + "QUIP_ONLINE": QUIP_ONLINE, "LLAMA_SPINQUANT": LLAMA_SPINQUANT, "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, } diff --git a/src/llmcompressor/modifiers/transform/presets/quip.py b/src/llmcompressor/modifiers/transform/presets/quip.py index e39c32e6d..4ce5e47ae 100644 --- a/src/llmcompressor/modifiers/transform/presets/quip.py +++ b/src/llmcompressor/modifiers/transform/presets/quip.py @@ -38,3 +38,61 @@ ), } ) + +# https://github.com/vllm-project/llm-compressor/blob/b43b27a2f277a5e62be4f8c713b84fd1c7aa116b/weight_transform.py#L24-L105 +QUIP_ONLINE = TransformConfig( + config_groups={ + "u_transform_q_o_down_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=[ + "re:.*.attn.q_proj$", + "re:.*.attn.o_proj$", + "re:.*.mlp.down_proj$", + ], + location="weight_input", + ) + ], + ), + "u_transform_k_v_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.attn.k_proj$", 
"re:.*.attn.v_proj$"], + location="weight_input", + ) + ], + ), + "u_transform_gate_up_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"], + location="weight_input", + ) + ], + ), + "v_transform_linear": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore=["re:.*.mlp.down_proj$", "lm_head"], + inverse=True, + ) + ], + ), + "v_transform_down_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.mlp.down_proj$"], + location="weight_output", + inverse=True, + ) + ], + ), + } +) diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index 194818b38..555b03fd6 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -2,23 +2,23 @@ # Ref: https://arxiv.org/pdf/2405.16406 Fig 1 -# All rotations -LLAMA_SPINQUANT = TransformConfig( - transform_groups={ +# Mergeable rotations R1 and R2 only +LLAMA_SPINQUANT_R1R2 = TransformConfig( + config_groups={ "R1": TransformScheme( type="hadamard", apply=[ TransformArgs( - targets=["embed_tokens", "o_proj", "down_proj"], + targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"], location="weight_output", ), TransformArgs( targets=[ - "q_proj", - "k_proj", - "v_proj", - "up_proj", - "gate_proj", + "re:.*q_proj$", + "re:.*k_proj$", + "re:.*v_proj$", + "re:.*up_proj$", + "re:.*gate_proj$", "lm_head", ], location="weight_input", @@ -30,23 +30,31 @@ type="hadamard", apply=[ TransformArgs( - targets=["v_proj"], + targets=["re:.*v_proj$"], location="weight_output", ), TransformArgs( - targets=["o_proj"], location="weight_input", inverse=True + targets=["re:.*o_proj$"], location="weight_input", inverse=True ), ], ), + } +) + +# All rotations +LLAMA_SPINQUANT = TransformConfig( + config_groups={ + "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], + "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( type="hadamard", apply=[ TransformArgs( - targets=["self_attn"], + targets=["re:.*self_attn$"], location="k_cache", ), TransformArgs( - targets=["self_attn"], + targets=["re:.*self_attn$"], location="q_attn", ), ], @@ -55,51 +63,11 @@ type="hadamard", apply=[ TransformArgs( - targets=["down_proj"], + targets=["re:.*down_proj$"], location="input", ), TransformArgs( - targets=["down_proj"], location="weight_input", inverse=True - ), - ], - ), - } -) - - -# Mergeable rotations R1 and R2 only -LLAMA_SPINQUANT_R1R2 = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["embed_tokens", "o_proj", "down_proj"], - location="weight_output", - ), - TransformArgs( - targets=[ - "q_proj", - "k_proj", - "v_proj", - "up_proj", - "gate_proj", - "lm_head", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["v_proj"], - location="weight_output", - ), - TransformArgs( - targets=["o_proj"], location="weight_input", inverse=True + targets=["re:.*down_proj$"], location="weight_input", inverse=True ), ], ), diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index d7ac10aaa..e94a3dc35 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -26,6 
+26,8 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier": ) model.config = TRANSFORM_PRESETS[model.preset_config] + return model + def on_initialize(self, state: State, **kwargs) -> bool: apply_transform_config(state.model, self.config) diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py index 8fca9639b..660bab0ef 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -7,10 +7,15 @@ @pytest.mark.parametrize( - "dtype,exp_max,exp_mse", [ - (torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 - (torch.float32, 4e-4, 2e-9) - ] + "dtype,exp_max,exp_mse", + [ + ( + torch.bfloat16, + 1.1, + 0.012, + ), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9), + ], ) def test_apply_correctness(dtype, exp_max, exp_mse): model = AutoModelForCausalLM.from_pretrained( From 5bd51df3668e7be7b2ea969bf85e1ae7f528d8ee Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 19:11:20 +0000 Subject: [PATCH 07/35] updated example Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 90051c9a8..62801935e 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -7,7 +7,7 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -57,6 +57,10 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ + # TODO preset_config="LLAMA_SPINQUANT_R1R2" outputs gibberish + # TODO preset_config="QUIP_ONLINE" outputs gibberish + # preset_config="QUIP" output sensible, but cannot load saved + # checkpoint or run evals (~4hrs to run) TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] @@ -72,12 +76,12 @@ def tokenize(sample): ) # # Confirm generations of the quantized model look sane. -# print("\n\n") -# print("========== SAMPLE GENERATION ==============") -# dispatch_for_generation(model) -# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -# output = model.generate(input_ids, max_new_tokens=100) -# print(tokenizer.decode(output[0])) +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) # print("==========================================\n\n") # Save to disk compressed. 
From 3c216dd685fdac9172213b1200be5a5ee91be532 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 8 Jul 2025 21:29:27 +0000 Subject: [PATCH 08/35] DummyModel script Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 23 ++-- examples/transform/spinquant_dummy.py | 112 ++++++++++++++++++ src/llmcompressor/entrypoints/oneshot.py | 3 +- .../modifiers/transform/presets/spinquant.py | 43 ++++--- 4 files changed, 154 insertions(+), 27 deletions(-) create mode 100644 examples/transform/spinquant_dummy.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 62801935e..1ec7b6516 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -7,7 +7,9 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" # TODO hidden size 3072 causes failure when creating hadamard +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -62,17 +64,18 @@ def tokenize(sample): # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), - QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. oneshot( model=model, - dataset=ds, recipe=recipe, - pipeline="sequential", - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, + # dataset=ds, + pipeline="datafree", + # max_seq_length=MAX_SEQUENCE_LENGTH, + # num_calibration_samples=NUM_CALIBRATION_SAMPLES, + log_dir=None, ) # # Confirm generations of the quantized model look sane. @@ -84,7 +87,7 @@ def tokenize(sample): print(tokenizer.decode(output[0])) # print("==========================================\n\n") -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) +# # Save to disk compressed. 
+# SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" +# model.save_pretrained(SAVE_DIR, save_compressed=True) +# tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py new file mode 100644 index 000000000..494b6c611 --- /dev/null +++ b/examples/transform/spinquant_dummy.py @@ -0,0 +1,112 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from compressed_tensors.utils import update_parameter_data +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.utils import dispatch_for_generation +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, +) + +hidden_dim = intermediate_dim = 64 +up_dim = 128 +num_embeddings = 12 + + +# TODO remove file before merging + + +class DummySelfAttn(torch.nn.Module): + def __init__(self, hidden_dim, intermediate_dim): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) + self.k_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) + self.v_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) + self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) + self.num_heads = 1 + self.num_key_value_groups = 1 + + def forward(self, hidden_states): + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + ### EAGER ATTENTION + attn_weights = torch.matmul(q.T, k) + + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_output = torch.matmul(attn_weights, v.T) + attn_output = attn_output.T.contiguous() + + return self.o_proj(attn_output) + + +class DummyMLP(torch.nn.Module): + def __init__(self, hidden_dim, up_dim): + super().__init__() + self.up_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) + self.gate_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) + self.down_proj = torch.nn.Linear(up_dim, hidden_dim, bias=None) + self.act_fn = torch.nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class DummyModel(torch.nn.Module): + def __init__(self, num_embeddings, hidden_dim, intermediate_dim, up_dim): + super().__init__() + self.embed_tokens = torch.nn.Embedding(num_embeddings, hidden_dim) + self.input_layernorm = LlamaRMSNorm(hidden_dim) + self.post_attention_layernorm = LlamaRMSNorm(hidden_dim) + self.self_attn = DummySelfAttn(hidden_dim, intermediate_dim) + self.mlp = DummyMLP(hidden_dim, up_dim) + self.lm_head = torch.nn.Linear(hidden_dim, num_embeddings, bias=None) + + def forward(self, input_ids): + x = self.embed_tokens(input_ids) + x = self.input_layernorm(x) + x = self.self_attn(x) + x = self.post_attention_layernorm(x) + x = self.mlp(x) + return self.lm_head(x) + + +model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) + +# TODO Uncomment this to see norm diff > 1e-6 +# This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 +# Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) +# https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 +# update_parameter_data( +# model.input_layernorm, +# 
torch.rand(model.input_layernorm.weight.shape), +# "weight", +# ) + +input_ids = torch.IntTensor([1, 2, 3, 4, 5]) +orig_output = model(input_ids) + +recipe = [ + # NOTE: preset_config="QUIP" output sensible, but cannot load saved + # checkpoint or run evals (~4hrs to run) + TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +oneshot( + model=model, + recipe=recipe, + pipeline="datafree", + log_dir=None, +) + +# # Confirm generations of the quantized model look the same +transformed_output = model(input_ids) + +print(f"Norm Diff {(orig_output-transformed_output).norm()}") +print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 55de99501..df815aa4f 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -125,7 +125,8 @@ def __init__( self.output_dir = output_dir # initialize the model and processor - pre_process(model_args) + # TODO Remove Comment before merge, this is just needed for DummyModel + # pre_process(model_args) # Set instance attributes self.model = self.model_args.model diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index 555b03fd6..d9765a6d5 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -9,43 +9,54 @@ type="hadamard", apply=[ TransformArgs( - targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"], + targets=[ + # outermost rotation + "re:.*embed_tokens$", + # attention rotations + "re:.*o_proj$", + # mlp rotations + "re:.*down_proj$", + ], location="weight_output", ), TransformArgs( targets=[ + # outermost rotation + "lm_head", + # attention rotations "re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$", + # mlp rotations "re:.*up_proj$", "re:.*gate_proj$", - "lm_head", ], location="weight_input", inverse=True, ), ], ), - "R2": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*v_proj$"], - location="weight_output", - ), - TransformArgs( - targets=["re:.*o_proj$"], location="weight_input", inverse=True - ), - ], - ), + # "R2": TransformScheme( + # type="hadamard", + # # TODO infer head_dim from config.json in SpinQuantModifier + # head_dim=128, + # apply=[ + # TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + # TransformArgs( + # targets=["re:.*o_proj$"], + # location="weight_input", + # inverse=True, + # ), + # ], + # ), } ) # All rotations LLAMA_SPINQUANT = TransformConfig( config_groups={ - "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], - "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], + # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], + # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( type="hadamard", apply=[ From bbcdc8ca6cd0c055e9baa543fe91fd0c10b88a11 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 9 Jul 2025 23:12:31 -0400 Subject: [PATCH 09/35] implement fuse_norm_linears Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/llmcompressor/modeling/fuse.py diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py new file mode 100644 index 000000000..4a9a34bb3 --- /dev/null +++ b/src/llmcompressor/modeling/fuse.py @@ -0,0 +1,28 @@ +from 
typing import Iterable + +import torch +from compressed_tensors import update_offload_parameter + +__all__ = ["fuse_norm_linears"] + + +def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): + """ + Fuse a norm layer into subsequent linear layers. This useful for ensuring transform + invariance between norm and linear layers. + + Note that a model cannot be properly trained after its norms have been fused + + :param norm: norm layer whose weight will be fused into subsequent linears + :param linears: linear layers which directly follow the norm layer + """ + if isinstance(norm, torch.nn.RMSNorm): + for linear in linears: + # spinquant does this op in float64 + new_weight = linear.weight * norm.weight + update_offload_parameter(linear, "weight", new_weight) + + update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) + + else: + raise ValueError(f"Cannot fuse norm of type {type(norm)}") From f5c2150eefb3e87b1719ecc75b03de9a370bb94c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 11:07:36 -0400 Subject: [PATCH 10/35] R1 working Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/__init__.py | 1 + src/llmcompressor/modeling/fuse.py | 13 +++++++++---- src/llmcompressor/modifiers/transform/transform.py | 11 +++++++++-- src/llmcompressor/pipelines/data_free/pipeline.py | 5 +++++ 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index e2c22ed1f..871955916 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .prepare import * +from .fuse import * \ No newline at end of file diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 4a9a34bb3..a87914a8b 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -1,7 +1,9 @@ from typing import Iterable import torch -from compressed_tensors import update_offload_parameter +from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter + +from transformers.models.llama.modeling_llama import LlamaRMSNorm __all__ = ["fuse_norm_linears"] @@ -16,10 +18,13 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if isinstance(norm, torch.nn.RMSNorm): + if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)): for linear in linears: - # spinquant does this op in float64 - new_weight = linear.weight * norm.weight + # NOTE: spinquant does this op in float64 + exec_device = get_execution_device(norm) + with align_module_device(norm, exec_device), align_module_device(linear, exec_device): + new_weight = linear.weight * norm.weight + update_offload_parameter(linear, "weight", new_weight) update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index e94a3dc35..3c59cde03 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -4,6 +4,7 @@ from pydantic import ValidationError, model_validator from llmcompressor.core import Event, EventType, State +from llmcompressor.modeling import fuse_norm_linears from llmcompressor.modifiers import Modifier from 
llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS @@ -29,13 +30,19 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier": return model def on_initialize(self, state: State, **kwargs) -> bool: - apply_transform_config(state.model, self.config) - return True def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + for layer in state.model.model.layers: + fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) + fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + apply_transform_config(state.model, self.config) + def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py index 587f7ca69..7ad6d56dc 100644 --- a/src/llmcompressor/pipelines/data_free/pipeline.py +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -5,6 +5,7 @@ from llmcompressor.core.session_functions import LifecycleCallbacks from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.utils.dev import dispatch_for_generation if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -27,5 +28,9 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ + # some ops are still performed on the model by modifiers + # we want those ops to occur on the GPU + dispatch_for_generation(model) + LifecycleCallbacks.calibration_epoch_start() LifecycleCallbacks.calibration_epoch_end() From dc5c30c54df8a19bc9c928e51c648b533c505d4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 11:38:49 -0400 Subject: [PATCH 11/35] add r2, increase precision Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 1 - src/llmcompressor/modeling/fuse.py | 7 +- .../modifiers/transform/presets/spinquant.py | 70 +++++++++++++++---- 3 files changed, 63 insertions(+), 15 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 1ec7b6516..96d65b997 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -59,7 +59,6 @@ def tokenize(sample): # Configure the quantization algorithm to run. 
# * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - # TODO preset_config="LLAMA_SPINQUANT_R1R2" outputs gibberish # TODO preset_config="QUIP_ONLINE" outputs gibberish # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index a87914a8b..3e059f7cb 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -23,7 +23,12 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) with align_module_device(norm, exec_device), align_module_device(linear, exec_device): - new_weight = linear.weight * norm.weight + + weight_dtype = linear.weight.dtype + + new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64) + + new_weight = new_weight.to(weight_dtype) update_offload_parameter(linear, "weight", new_weight) diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index d9765a6d5..62dfb2477 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -36,25 +36,69 @@ ), ], ), - # "R2": TransformScheme( - # type="hadamard", - # # TODO infer head_dim from config.json in SpinQuantModifier - # head_dim=128, - # apply=[ - # TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - # TransformArgs( - # targets=["re:.*o_proj$"], - # location="weight_input", - # inverse=True, - # ), - # ], - # ), + "R2": TransformScheme( + type="hadamard", + # TODO infer head_dim from config.json in SpinQuantModifier + head_dim=128, + apply=[ + TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + TransformArgs( + targets=["re:.*o_proj$"], + location="weight_input", + inverse=True, + ), + ], + ), } ) # All rotations LLAMA_SPINQUANT = TransformConfig( config_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=[ + # outermost rotation + "re:.*embed_tokens$", + # attention rotations + "re:.*o_proj$", + # mlp rotations + "re:.*down_proj$", + ], + location="weight_output", + ), + TransformArgs( + targets=[ + # outermost rotation + "lm_head", + # attention rotations + "re:.*q_proj$", + "re:.*k_proj$", + "re:.*v_proj$", + # mlp rotations + "re:.*up_proj$", + "re:.*gate_proj$", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + # TODO infer head_dim from config.json in SpinQuantModifier + head_dim=128, + apply=[ + TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + TransformArgs( + targets=["re:.*o_proj$"], + location="weight_input", + inverse=True, + ), + ], + ), # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( From 7172c2604f0301d05ec2be5cb4b1f58d49331d50 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:49:03 -0400 Subject: [PATCH 12/35] spinquant modifier Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 4 +- examples/transform/spinquant_dummy.py | 4 +- .../modifiers/transform/__init__.py | 3 +- .../modifiers/transform/presets/__init__.py | 9 - .../modifiers/transform/quip/base.py | 0 .../{presets/quip.py => quip/template.py} | 0 .../modifiers/transform/spinquant/__init__.py | 1 + 
.../modifiers/transform/spinquant/base.py | 215 ++++++++++++++++++ .../spinquant.py => spinquant/template.py} | 0 .../modifiers/transform/transform.py | 65 ------ 10 files changed, 221 insertions(+), 80 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/presets/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/quip/base.py rename src/llmcompressor/modifiers/transform/{presets/quip.py => quip/template.py} (100%) create mode 100644 src/llmcompressor/modifiers/transform/spinquant/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/spinquant/base.py rename src/llmcompressor/modifiers/transform/{presets/spinquant.py => spinquant/template.py} (100%) delete mode 100644 src/llmcompressor/modifiers/transform/transform.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 96d65b997..8c87cb6a6 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -3,7 +3,7 @@ from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. @@ -62,7 +62,7 @@ def tokenize(sample): # TODO preset_config="QUIP_ONLINE" outputs gibberish # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) - TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + SpinQuantModifier(rotations=["R1", "R2"]), # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py index 494b6c611..3e8c9d483 100644 --- a/examples/transform/spinquant_dummy.py +++ b/examples/transform/spinquant_dummy.py @@ -4,7 +4,7 @@ from compressed_tensors.utils import update_parameter_data from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, @@ -94,7 +94,7 @@ def forward(self, input_ids): recipe = [ # NOTE: preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) - TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + SpinQuantModifier(rotations=["R1", "R2"]), # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index c43958136..9956d0340 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from .presets import TRANSFORM_PRESETS -from .transform import TransformModifier +from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py deleted file mode 100644 index 0d4a06b90..000000000 --- a/src/llmcompressor/modifiers/transform/presets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .quip import QUIP, QUIP_ONLINE -from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 - -TRANSFORM_PRESETS = { - "QUIP": 
QUIP, - "QUIP_ONLINE": QUIP_ONLINE, - "LLAMA_SPINQUANT": LLAMA_SPINQUANT, - "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, -} diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llmcompressor/modifiers/transform/presets/quip.py b/src/llmcompressor/modifiers/transform/quip/template.py similarity index 100% rename from src/llmcompressor/modifiers/transform/presets/quip.py rename to src/llmcompressor/modifiers/transform/quip/template.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py new file mode 100644 index 000000000..773cfc466 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -0,0 +1 @@ +from .base import * \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py new file mode 100644 index 000000000..6d0c0cca3 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -0,0 +1,215 @@ +from typing import Optional, List, Literal + +from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config +from pydantic import BaseModel, field_validator, Field + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modeling import fuse_norm_linears +from llmcompressor.modifiers import Modifier +from enum import Enum + +from transformers import PreTrainedModel + + +class SpinQuantMappings(BaseModel): + embedding: str + + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + +class NormMapping(BaseModel): + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + + +llama_spinquant = SpinQuantMappings( + embedding="re:.*embed_tokens$", + + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + + lm_head="lm_head", +) + +llama_norm_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ) +] + +class SpinquantRotation(Enum): + R1 = "R1" + R2 = "R2" + R3 = "R3" + R4 = "R4" + +class SpinQuantModifier(Modifier): + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") + randomize: bool = Field(default=False) + learnable: bool = Field(default=False) + + mappings: Optional[SpinQuantMappings] = None + norm_mappings: Optional[List[NormMapping]] = None + + transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + + def on_initialize(self, state: State, **kwargs) -> bool: + # HARDCODE + self.mappings = llama_spinquant + self.norm_mappings = llama_norm_mappings + + if self.transform_config is not None: + if 
self.mappings is not None: + raise ValueError() + + return True + + config_groups = {} + for rotation in self.rotations: + if rotation == SpinquantRotation.R1: + config_groups["R1"] = self._create_r1_scheme() + + if rotation == SpinquantRotation.R2: + config_groups["R2"] = self._create_r2_scheme(state.model) + + if rotation == SpinquantRotation.R3: + config_groups["R3"] = self._create_r3_scheme() + + if rotation == SpinquantRotation.R4: + config_groups["R4"] = self._create_r4_scheme() + + self.transform_config = TransformConfig(config_groups=config_groups) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + for layer in state.model.model.layers: + fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) + fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + apply_transform_config(state.model, self.transform_config) + + + + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + + def _create_r1_scheme(self) -> TransformScheme: + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + apply=[ + TransformArgs( + targets=[ + self.mappings.embedding, + self.mappings.attn_o, + *self.mappings.mlp_out, + ], + location="weight_output", + ), + TransformArgs( + targets=[ + self.mappings.attn_q, + self.mappings.attn_k, + self.mappings.attn_v, + *self.mappings.mlp_in, + self.mappings.lm_head + ], + location="weight_input", + inverse=True, + ), + ] + ) + + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + head_dim=head_dim, + apply=[ + TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), + TransformArgs( + targets=[self.mappings.attn_o], + location="weight_input", + inverse=True, + ), + ], + ) + + + def _create_r3_scheme(self) -> TransformScheme: + raise NotImplementedError() + + + def _create_r4_scheme(self) -> TransformScheme: + raise NotImplementedError() \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/spinquant/template.py similarity index 100% rename from src/llmcompressor/modifiers/transform/presets/spinquant.py rename to src/llmcompressor/modifiers/transform/spinquant/template.py diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py deleted file mode 100644 index 3c59cde03..000000000 --- 
a/src/llmcompressor/modifiers/transform/transform.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Optional - -from compressed_tensors.transform import TransformConfig, apply_transform_config -from pydantic import ValidationError, model_validator - -from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import fuse_norm_linears -from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS - - -class TransformModifier(Modifier): - preset_config: Optional[str] = None - config: Optional[TransformConfig] = None - - # model validator to validate both preset and config groups are not provided - @model_validator(mode="after") - def validate_model_after(model: "TransformModifier") -> "TransformModifier": - if model.preset_config is None and model.config is None: - raise ValidationError("Either a config or a preset_config must be provided") - - if model.preset_config is not None: - if model.preset_config not in TRANSFORM_PRESETS: - raise ValidationError( - f"Invalid preset_config '{model.preset_config}' " - f"must be in {TRANSFORM_PRESETS.keys()}" - ) - model.config = TRANSFORM_PRESETS[model.preset_config] - - return model - - def on_initialize(self, state: State, **kwargs) -> bool: - return True - - def on_start(self, state: State, event: Event, **kwargs): - self.started_ = True - - for layer in state.model.model.layers: - fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) - fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) - - # needs to happen after the model has been hooked to execute on the GPU - # otherwise we're applying weight transforms on CPU - apply_transform_config(state.model, self.config) - - def on_event(self, state: State, event: Event, **kwargs): - if event.type_ == EventType.CALIBRATION_EPOCH_START: - if not self.started_: - self.on_start(state, None) - - elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: - pass - - elif event.type_ == EventType.CALIBRATION_EPOCH_END: - if not self.ended_: - self.on_end(state, None) - - def on_end(self, state: State, event: Event, **kwargs): - self.ended_ = True - - def on_finalize(self, state: State, **kwargs) -> bool: - if not self.ended_: - self.on_end(state, None) - - return True From 9298e8268d8c6c11ebd7b6c4dd0a433c639f0971 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:50:00 -0400 Subject: [PATCH 13/35] remove space Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 6d0c0cca3..8bf2e5cb1 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -78,7 +78,6 @@ class SpinquantRotation(Enum): class SpinQuantModifier(Modifier): rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) - transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") randomize: bool = Field(default=False) learnable: bool = Field(default=False) From f77226d12b8e6e7e5556f70b58b392d1b97d2025 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:51:20 -0400 Subject: [PATCH 14/35] use iterable Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8bf2e5cb1..76f38361c 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Literal +from typing import Optional, List, Literal, Iterable from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config from pydantic import BaseModel, field_validator, Field @@ -77,7 +77,7 @@ class SpinquantRotation(Enum): R4 = "R4" class SpinQuantModifier(Modifier): - rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + rotations: Iterable[SpinquantRotation] = ("R1", "R2") transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") randomize: bool = Field(default=False) learnable: bool = Field(default=False) From fdb64b54876f81c5e34fe020840334bc616ba6d6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:58:24 -0400 Subject: [PATCH 15/35] add rotation validation Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 76f38361c..8e786fb7e 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -87,6 +87,12 @@ class SpinQuantModifier(Modifier): transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + @field_validator("rotations", mode="before") + def validate_rotations(cls, value): + if isinstance(value, Iterable): + return tuple(v.upper() for v in value) + return value + def on_initialize(self, state: State, **kwargs) -> bool: # HARDCODE self.mappings = llama_spinquant From 5daa2d5a0cb31f32911942762a51c4ea69822f48 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 15:11:54 -0400 Subject: [PATCH 16/35] embedding fusion Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8e786fb7e..813e1335a 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -125,6 +125,18 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + # TODO: use norm mappings + # Embedding fusion + # theoretically, doesn't do anything. 
Doesn't seem to help model sanity either + from compressed_tensors import update_offload_parameter + for W in [state.model.model.embed_tokens]: + W_ = W.weight.data.double() + W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) + + update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight) + + # TODO: use norm mappings + # layer norm fusion for layer in state.model.model.layers: fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) From 0e9af7b6d1ff8d574c373b741aeeaf3733b4ee47 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 10:38:43 -0400 Subject: [PATCH 17/35] add missing norm fusion Signed-off-by: Kyle Sayers --- examples/transform/spinquant_dummy.py | 9 +-- src/llmcompressor/modeling/__init__.py | 2 +- src/llmcompressor/modeling/fuse.py | 18 +++-- .../modifiers/transform/spinquant/__init__.py | 2 +- .../modifiers/transform/spinquant/base.py | 75 +++++++++++-------- 5 files changed, 62 insertions(+), 44 deletions(-) diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py index 3e8c9d483..71db967de 100644 --- a/examples/transform/spinquant_dummy.py +++ b/examples/transform/spinquant_dummy.py @@ -1,14 +1,13 @@ -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer import torch from compressed_tensors.utils import update_parameter_data +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama.modeling_llama import LlamaRMSNorm + from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation -from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, -) hidden_dim = intermediate_dim = 64 up_dim = 128 diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index 871955916..76b6b0391 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .fuse import * from .prepare import * -from .fuse import * \ No newline at end of file diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 3e059f7cb..cb88ecc22 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -1,8 +1,11 @@ from typing import Iterable import torch -from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter - +from compressed_tensors import ( + align_module_device, + get_execution_device, + update_offload_parameter, +) from transformers.models.llama.modeling_llama import LlamaRMSNorm __all__ = ["fuse_norm_linears"] @@ -22,14 +25,17 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) for linear in linears: # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) - with align_module_device(norm, exec_device), align_module_device(linear, exec_device): - + with align_module_device(norm, exec_device), align_module_device( + linear, exec_device + ): weight_dtype = linear.weight.dtype - new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64) + new_weight = linear.weight.to(torch.float64) * 
norm.weight.to( + torch.float64 + ) new_weight = new_weight.to(weight_dtype) - + update_offload_parameter(linear, "weight", new_weight) update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py index 773cfc466..9b5ed21c9 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/__init__.py +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -1 +1 @@ -from .base import * \ No newline at end of file +from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 813e1335a..31b1bbdee 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,14 +1,18 @@ -from typing import Optional, List, Literal, Iterable +from enum import Enum +from typing import Iterable, List, Literal, Optional -from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config -from pydantic import BaseModel, field_validator, Field +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State from llmcompressor.modeling import fuse_norm_linears from llmcompressor.modifiers import Modifier -from enum import Enum - -from transformers import PreTrainedModel class SpinQuantMappings(BaseModel): @@ -29,9 +33,10 @@ class SpinQuantMappings(BaseModel): def cast_to_list(cls, value): if isinstance(value, str): return [value] - + return value - + + class NormMapping(BaseModel): norm: str linears: List[str] @@ -40,22 +45,18 @@ class NormMapping(BaseModel): def cast_to_list(cls, value): if isinstance(value, str): return [value] - - return value + return value llama_spinquant = SpinQuantMappings( embedding="re:.*embed_tokens$", - attn_q="re:.*q_proj$", attn_k="re:.*k_proj$", attn_v="re:.*v_proj$", attn_o="re:.*o_proj$", - mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], mlp_out="re:.*down_proj$", - lm_head="lm_head", ) @@ -67,25 +68,31 @@ def cast_to_list(cls, value): NormMapping( norm="re:.*post_attention_layernorm$", linears=["re:.*up_proj$", "re:.*gate_proj$"], - ) + ), ] + class SpinquantRotation(Enum): R1 = "R1" R2 = "R2" R3 = "R3" R4 = "R4" + class SpinQuantModifier(Modifier): rotations: Iterable[SpinquantRotation] = ("R1", "R2") - transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard" + ) randomize: bool = Field(default=False) learnable: bool = Field(default=False) mappings: Optional[SpinQuantMappings] = None norm_mappings: Optional[List[NormMapping]] = None - - transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + + transform_config: Optional[TransformConfig] = ( + None # optional override for more fine-grained control + ) @field_validator("rotations", mode="before") def validate_rotations(cls, value): @@ -101,7 +108,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: if self.transform_config is not None: if self.mappings is not None: raise ValueError() - + return True config_groups = {} @@ -129,6 +136,7 @@ def on_start(self, state: 
State, event: Event, **kwargs): # Embedding fusion # theoretically, doesn't do anything. Doesn't seem to help model sanity either from compressed_tensors import update_offload_parameter + for W in [state.model.model.embed_tokens]: W_ = W.weight.data.double() W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) @@ -138,16 +146,24 @@ def on_start(self, state: State, event: Event, **kwargs): # TODO: use norm mappings # layer norm fusion for layer in state.model.model.layers: - fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) - fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + fuse_norm_linears( + layer.input_layernorm, + ( + layer.self_attn.q_proj, + layer.self_attn.k_proj, + layer.self_attn.v_proj, + ), + ) + fuse_norm_linears( + layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj) + ) + + fuse_norm_linears(state.model.model.norm, (state.model.lm_head,)) # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU apply_transform_config(state.model, self.transform_config) - - - def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: @@ -169,7 +185,6 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( type=self.transform_type, @@ -190,14 +205,14 @@ def _create_r1_scheme(self) -> TransformScheme: self.mappings.attn_k, self.mappings.attn_v, *self.mappings.mlp_in, - self.mappings.lm_head + self.mappings.lm_head, ], location="weight_input", inverse=True, ), - ] + ], ) - + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: config = model.config @@ -207,7 +222,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: head_dim = config.hidden_size // config.num_attention_heads else: raise NotImplementedError() - + return TransformScheme( type=self.transform_type, randomize=self.randomize, @@ -223,10 +238,8 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: ], ) - def _create_r3_scheme(self) -> TransformScheme: raise NotImplementedError() - def _create_r4_scheme(self) -> TransformScheme: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() From fce83be83f5c4ec01b1717263c1a6effcacf3e8d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 12:40:27 -0400 Subject: [PATCH 18/35] use norm mappings Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 25 ++++++-- .../modifiers/transform/spinquant/base.py | 64 ++++++++++--------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index cb88ecc22..12e21f14b 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -8,7 +8,24 @@ ) from transformers.models.llama.modeling_llama import LlamaRMSNorm -__all__ = ["fuse_norm_linears"] +__all__ = ["normalize_embedding", "fuse_norm_linears"] + + +PRECISION = torch.float64 + + +def normalize_embedding(embedding: torch.nn.Module): + if isinstance(embedding, (torch.nn.Embedding)): + with align_module_device(embedding): + weight_dtype = embedding.weight.dtype + weight = embedding.weight.to(PRECISION) + new_weight = weight - weight.mean(dim=-1, keepdim=True) + new_weight = new_weight.to(weight_dtype) + + 
update_offload_parameter(embedding, "weight", new_weight) + + else: + raise ValueError(f"Cannot normalize embedding of type {type(embedding)}") def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): @@ -29,11 +46,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) linear, exec_device ): weight_dtype = linear.weight.dtype - - new_weight = linear.weight.to(torch.float64) * norm.weight.to( - torch.float64 - ) - + new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) new_weight = new_weight.to(weight_dtype) update_offload_parameter(linear, "weight", new_weight) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 31b1bbdee..c6b1c3087 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Iterable, List, Literal, Optional +from compressed_tensors import match_named_modules, is_match from compressed_tensors.transform import ( TransformArgs, TransformConfig, @@ -11,7 +12,7 @@ from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import fuse_norm_linears +from llmcompressor.modeling import normalize_embedding, fuse_norm_linears from llmcompressor.modifiers import Modifier @@ -69,6 +70,10 @@ def cast_to_list(cls, value): norm="re:.*post_attention_layernorm$", linears=["re:.*up_proj$", "re:.*gate_proj$"], ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), ] @@ -132,36 +137,10 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True - # TODO: use norm mappings - # Embedding fusion - # theoretically, doesn't do anything. 
Doesn't seem to help model sanity either - from compressed_tensors import update_offload_parameter - - for W in [state.model.model.embed_tokens]: - W_ = W.weight.data.double() - W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) - - update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight) - - # TODO: use norm mappings - # layer norm fusion - for layer in state.model.model.layers: - fuse_norm_linears( - layer.input_layernorm, - ( - layer.self_attn.q_proj, - layer.self_attn.k_proj, - layer.self_attn.v_proj, - ), - ) - fuse_norm_linears( - layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj) - ) - - fuse_norm_linears(state.model.model.norm, (state.model.lm_head,)) - # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU + self._prenormalize_embeddings(state.model) + self._fuse_norms(state.model) apply_transform_config(state.model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): @@ -185,6 +164,33 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def _prenormalize_embeddings(self, model: PreTrainedModel): + for _, embedding in match_named_modules( + model, [self.mappings.embedding], warn_on_fail=True + ): + normalize_embedding(embedding) + + def _fuse_norms(self, model: PreTrainedModel): + for mapping in self.norm_mappings: + targets = (mapping.norm, *mapping.linears) + matches = dict() + + for name, module in model.named_modules(): + # match until we get a full set + for target in targets: + if is_match(name, module, target): + if target in matches: + raise ValueError("Cannot match twice") + matches[target] = module + + # once we have a full set, fuse and reset + if all(target in matches for target in targets): + fuse_norm_linears( + matches[mapping.norm], + (matches[target] for target in mapping.linears), + ) + matches = dict() + def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( type=self.transform_type, From a979f8aff43ef81322d4b8934d03cb61fe65d360 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 13:02:09 -0400 Subject: [PATCH 19/35] break into separate files Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/__init__.py | 2 + .../modifiers/transform/spinquant/base.py | 83 ++++--------------- .../modifiers/transform/spinquant/mappings.py | 42 ++++++++++ .../transform/spinquant/norm_mappings.py | 35 ++++++++ 4 files changed, 93 insertions(+), 69 deletions(-) create mode 100644 src/llmcompressor/modifiers/transform/spinquant/mappings.py create mode 100644 src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py index 9b5ed21c9..8bdc93d14 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/__init__.py +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -1 +1,3 @@ +# flake8: noqa + from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index c6b1c3087..7c76aeca5 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,80 +1,22 @@ from enum import Enum from typing import Iterable, List, Literal, Optional -from compressed_tensors import match_named_modules, is_match +from compressed_tensors import is_match, 
match_named_modules from compressed_tensors.transform import ( TransformArgs, TransformConfig, TransformScheme, apply_transform_config, ) -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import normalize_embedding, fuse_norm_linears +from llmcompressor.modeling import fuse_norm_linears, normalize_embedding from llmcompressor.modifiers import Modifier - -class SpinQuantMappings(BaseModel): - embedding: str - - attn_q: str - attn_k: str - attn_v: str - attn_o: str - attn_head_dim: Optional[int] = Field(default=None) - - mlp_in: List[str] # up_proj, gate_proj - mlp_out: List[str] # down_proj - - lm_head: str - - @field_validator("mlp_in", "mlp_out", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -class NormMapping(BaseModel): - norm: str - linears: List[str] - - @field_validator("linears", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -llama_spinquant = SpinQuantMappings( - embedding="re:.*embed_tokens$", - attn_q="re:.*q_proj$", - attn_k="re:.*k_proj$", - attn_v="re:.*v_proj$", - attn_o="re:.*o_proj$", - mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], - mlp_out="re:.*down_proj$", - lm_head="lm_head", -) - -llama_norm_mappings = [ - NormMapping( - norm="re:.*input_layernorm$", - linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], - ), - NormMapping( - norm="re:.*post_attention_layernorm$", - linears=["re:.*up_proj$", "re:.*gate_proj$"], - ), - NormMapping( - norm="model.norm", - linears=["lm_head"], - ), -] +from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings +from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping class SpinquantRotation(Enum): @@ -92,12 +34,15 @@ class SpinQuantModifier(Modifier): randomize: bool = Field(default=False) learnable: bool = Field(default=False) + # norm mappings separate from spinquant mappings to allow users to + # override spinquant mappings with transform_config without overriding norms + # we can combine these mappings, but it requires some more validation logic + # maybe there's a reason to keep if other modifiers want norm fusing, idk mappings: Optional[SpinQuantMappings] = None norm_mappings: Optional[List[NormMapping]] = None - transform_config: Optional[TransformConfig] = ( - None # optional override for more fine-grained control - ) + # optional override for more fine-grained control + transform_config: Optional[TransformConfig] = None @field_validator("rotations", mode="before") def validate_rotations(cls, value): @@ -106,9 +51,9 @@ def validate_rotations(cls, value): return value def on_initialize(self, state: State, **kwargs) -> bool: - # HARDCODE - self.mappings = llama_spinquant - self.norm_mappings = llama_norm_mappings + # TODO: more validation + self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__] + self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__] if self.transform_config is not None: if self.mappings is not None: diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py new file mode 100644 index 000000000..acf692d22 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -0,0 +1,42 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, 
Field, field_validator + + +class SpinQuantMappings(BaseModel): + embedding: str + + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = SpinQuantMappings( + embedding="re:.*embed_tokens$", + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + lm_head="lm_head", +) + + +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = { + "LlamaForCausalLM": _default_mappings, +} diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py new file mode 100644 index 000000000..cefb987ca --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -0,0 +1,35 @@ +from typing import Dict, List + +from pydantic import BaseModel, field_validator + + +class NormMapping(BaseModel): + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_norm_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), +] + +NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { + "LlamaForCausalLM": _default_norm_mappings, +} From 4cab29ef7060e6f67c43881fa44adeae2a0c4258 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 14:37:34 -0400 Subject: [PATCH 20/35] small cleanup Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 7c76aeca5..e448bd372 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -62,18 +62,17 @@ def on_initialize(self, state: State, **kwargs) -> bool: return True config_groups = {} - for rotation in self.rotations: - if rotation == SpinquantRotation.R1: - config_groups["R1"] = self._create_r1_scheme() + if SpinquantRotation.R1 in self.rotations: + config_groups["R1"] = self._create_r1_scheme() - if rotation == SpinquantRotation.R2: - config_groups["R2"] = self._create_r2_scheme(state.model) + if SpinquantRotation.R2 in self.rotations: + config_groups["R2"] = self._create_r2_scheme(state.model) - if rotation == SpinquantRotation.R3: - config_groups["R3"] = self._create_r3_scheme() + if SpinquantRotation.R3 in self.rotations: + config_groups["R3"] = self._create_r3_scheme() - if rotation == SpinquantRotation.R4: - config_groups["R4"] = self._create_r4_scheme() + if SpinquantRotation.R4 in self.rotations: + config_groups["R4"] = self._create_r4_scheme() self.transform_config = TransformConfig(config_groups=config_groups) From f1cc987c00163705b46e5ad286a0e87732196323 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:40:18 -0400 Subject: [PATCH 21/35] cleanup 
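As an aside to the mapping registries introduced a couple of patches above: each NormMapping pairs a norm with the linears it feeds so the norm's scale can be folded into those weights. A minimal plain-torch sketch of that fusion identity, not part of the patch series (hand-rolled rms_norm and toy dimensions rather than the repo's fuse_norm_linears):

import torch

def rms_norm(x, weight, eps=1e-6):
    # minimal RMSNorm: x / rms(x) * weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
hidden = 16
gamma = torch.rand(hidden) + 0.5          # norm weight to be fused away
linear = torch.nn.Linear(hidden, 8, bias=False)
x = torch.randn(3, hidden)

y_ref = linear(rms_norm(x, gamma))

# fold the norm scale into the linear's input columns, then reset the norm weight to ones
with torch.no_grad():
    linear.weight.mul_(gamma)             # column j scaled by gamma[j]
y_fused = linear(rms_norm(x, torch.ones(hidden)))

assert torch.allclose(y_ref, y_fused, atol=1e-5)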
Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 32 ++++----- src/llmcompressor/entrypoints/oneshot.py | 3 +- .../modifiers/transform/test_dummy_model.py | 70 +++++++++---------- 3 files changed, 47 insertions(+), 58 deletions(-) rename examples/transform/spinquant_dummy.py => tests/llmcompressor/modifiers/transform/test_dummy_model.py (68%) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 8c87cb6a6..790619b08 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -2,13 +2,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. -# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" -# MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" # TODO hidden size 3072 causes failure when creating hadamard MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( @@ -57,36 +55,32 @@ def tokenize(sample): ds = ds.map(tokenize, remove_columns=ds.column_names) # Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - # TODO preset_config="QUIP_ONLINE" outputs gibberish - # preset_config="QUIP" output sensible, but cannot load saved - # checkpoint or run evals (~4hrs to run) - SpinQuantModifier(rotations=["R1", "R2"]), - # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + SpinQuantModifier(rotations=["R1", "R2"], transform_type="random-hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. oneshot( model=model, recipe=recipe, - # dataset=ds, - pipeline="datafree", - # max_seq_length=MAX_SEQUENCE_LENGTH, - # num_calibration_samples=NUM_CALIBRATION_SAMPLES, - log_dir=None, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# # Confirm generations of the quantized model look sane. +# Confirm generations of the quantized model look sane. print("\n\n") print("========== SAMPLE GENERATION ==============") dispatch_for_generation(model) input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") output = model.generate(input_ids, max_new_tokens=100) print(tokenizer.decode(output[0])) -# print("==========================================\n\n") +print("==========================================\n\n") -# # Save to disk compressed. -# SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" -# model.save_pretrained(SAVE_DIR, save_compressed=True) -# tokenizer.save_pretrained(SAVE_DIR) +# Save to disk compressed. 
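The R1/R2 rotations configured in the example recipe above are function-preserving because each rotation is folded into one weight and its inverse into the adjacent weight. A rough sketch of that identity, not part of the patch series, with a random orthogonal matrix standing in for a Hadamard rotation (toy dimensions, no biases; illustration only, not the compressed-tensors implementation):

import torch

torch.manual_seed(0)
d_in, d, d_out = 8, 16, 4
a = torch.nn.Linear(d_in, d, bias=False)
b = torch.nn.Linear(d, d_out, bias=False)
q, _ = torch.linalg.qr(torch.randn(d, d))   # orthogonal: q.T @ q == I

x = torch.randn(3, d_in)
y_ref = b(a(x))

with torch.no_grad():
    a.weight.copy_(q @ a.weight)            # rotate on the "weight_output" side
    b.weight.copy_(b.weight @ q.T)          # inverse rotation on the "weight_input" side

y_rot = b(a(x))
assert torch.allclose(y_ref, y_rot, atol=1e-5)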
+SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index cfd3b551f..9219f21fb 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -125,8 +125,7 @@ def __init__( self.output_dir = output_dir # initialize the model and processor - # TODO Remove Comment before merge, this is just needed for DummyModel - # pre_process(model_args) + pre_process(model_args) # Set instance attributes self.model = self.model_args.model diff --git a/examples/transform/spinquant_dummy.py b/tests/llmcompressor/modifiers/transform/test_dummy_model.py similarity index 68% rename from examples/transform/spinquant_dummy.py rename to tests/llmcompressor/modifiers/transform/test_dummy_model.py index 71db967de..020a61e99 100644 --- a/examples/transform/spinquant_dummy.py +++ b/tests/llmcompressor/modifiers/transform/test_dummy_model.py @@ -14,9 +14,6 @@ num_embeddings = 12 -# TODO remove file before merging - - class DummySelfAttn(torch.nn.Module): def __init__(self, hidden_dim, intermediate_dim): super().__init__() @@ -75,37 +72,36 @@ def forward(self, input_ids): return self.lm_head(x) -model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) - -# TODO Uncomment this to see norm diff > 1e-6 -# This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 -# Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) -# https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 -# update_parameter_data( -# model.input_layernorm, -# torch.rand(model.input_layernorm.weight.shape), -# "weight", -# ) - -input_ids = torch.IntTensor([1, 2, 3, 4, 5]) -orig_output = model(input_ids) - -recipe = [ - # NOTE: preset_config="QUIP" output sensible, but cannot load saved - # checkpoint or run evals (~4hrs to run) - SpinQuantModifier(rotations=["R1", "R2"]), - # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), -] - -oneshot( - model=model, - recipe=recipe, - pipeline="datafree", - log_dir=None, -) - -# # Confirm generations of the quantized model look the same -transformed_output = model(input_ids) - -print(f"Norm Diff {(orig_output-transformed_output).norm()}") -print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") +def test_dummy_model(): + model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) + + # TODO Uncomment this to see norm diff > 1e-6 + # This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 + # Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) + # https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 + # update_parameter_data( + # model.input_layernorm, + # torch.rand(model.input_layernorm.weight.shape), + # "weight", + # ) + + input_ids = torch.IntTensor([1, 2, 3, 4, 5]) + orig_output = model(input_ids) + + recipe = [ + SpinQuantModifier(rotations=["R1", "R2"]), + ] + + # TODO: work around preprocessing? 
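The footnote referenced in the test comment above comes down to a simple property: an RMSNorm with unit weight commutes with an orthogonal rotation (the rotation preserves the vector norm), while one with a non-trivial weight does not, which is why the norm weights are fused into the following linears first. A quick numerical check, assuming nothing beyond plain torch and not part of the patch series:

import torch

def rms_norm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
d = 16
x = torch.randn(5, d, dtype=torch.float64)
q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))
ones = torch.ones(d, dtype=torch.float64)
gamma = torch.rand(d, dtype=torch.float64) + 0.5

# unit weight: norm(x @ q) == norm(x) @ q
print(torch.allclose(rms_norm(x @ q, ones), rms_norm(x, ones) @ q))    # True
# non-trivial weight: the rotation no longer commutes
print(torch.allclose(rms_norm(x @ q, gamma), rms_norm(x, gamma) @ q))  # False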
+ oneshot( + model=model, + recipe=recipe, + pipeline="datafree", + log_dir=None, + ) + + # # Confirm generations of the quantized model look the same + transformed_output = model(input_ids) + + print(f"Norm Diff {(orig_output-transformed_output).norm()}") + print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") From a7bb2e2872cca3421e877de62bcb8a195f63a223 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:40:55 -0400 Subject: [PATCH 22/35] more cleanup Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/template.py | 130 ------------------ 1 file changed, 130 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/template.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/template.py b/src/llmcompressor/modifiers/transform/spinquant/template.py deleted file mode 100644 index 62dfb2477..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/template.py +++ /dev/null @@ -1,130 +0,0 @@ -from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme - -# Ref: https://arxiv.org/pdf/2405.16406 Fig 1 - -# Mergeable rotations R1 and R2 only -LLAMA_SPINQUANT_R1R2 = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=[ - # outermost rotation - "re:.*embed_tokens$", - # attention rotations - "re:.*o_proj$", - # mlp rotations - "re:.*down_proj$", - ], - location="weight_output", - ), - TransformArgs( - targets=[ - # outermost rotation - "lm_head", - # attention rotations - "re:.*q_proj$", - "re:.*k_proj$", - "re:.*v_proj$", - # mlp rotations - "re:.*up_proj$", - "re:.*gate_proj$", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - # TODO infer head_dim from config.json in SpinQuantModifier - head_dim=128, - apply=[ - TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - TransformArgs( - targets=["re:.*o_proj$"], - location="weight_input", - inverse=True, - ), - ], - ), - } -) - -# All rotations -LLAMA_SPINQUANT = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=[ - # outermost rotation - "re:.*embed_tokens$", - # attention rotations - "re:.*o_proj$", - # mlp rotations - "re:.*down_proj$", - ], - location="weight_output", - ), - TransformArgs( - targets=[ - # outermost rotation - "lm_head", - # attention rotations - "re:.*q_proj$", - "re:.*k_proj$", - "re:.*v_proj$", - # mlp rotations - "re:.*up_proj$", - "re:.*gate_proj$", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - # TODO infer head_dim from config.json in SpinQuantModifier - head_dim=128, - apply=[ - TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - TransformArgs( - targets=["re:.*o_proj$"], - location="weight_input", - inverse=True, - ), - ], - ), - # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], - # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], - "R3": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*self_attn$"], - location="k_cache", - ), - TransformArgs( - targets=["re:.*self_attn$"], - location="q_attn", - ), - ], - ), - "R4": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*down_proj$"], - location="input", - ), - TransformArgs( - targets=["re:.*down_proj$"], location="weight_input", inverse=True - ), - ], - ), - } -) From 0cf0188987898587c6f5d96a53b97264c8ee0435 Mon Sep 17 
00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:48:49 -0400 Subject: [PATCH 23/35] make new weight on cpu Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 12e21f14b..33e91601c 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -51,7 +51,8 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) update_offload_parameter(linear, "weight", new_weight) - update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) + new_norm_weight = torch.ones_like(norm.weight, device="cpu") + update_offload_parameter(norm, "weight", new_norm_weight) else: raise ValueError(f"Cannot fuse norm of type {type(norm)}") From 53ea3076161f8562fe7653f9f6cb57c48da75ae4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 16:35:55 -0400 Subject: [PATCH 24/35] standardize, make modifier serializable Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 2 +- .../modifiers/transform/quip/base.py | 0 .../modifiers/transform/quip/template.py | 98 ------------------- .../modifiers/transform/spinquant/base.py | 13 +-- src/llmcompressor/pipelines/registry.py | 5 + 5 files changed, 13 insertions(+), 105 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/quip/base.py delete mode 100644 src/llmcompressor/modifiers/transform/quip/template.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 790619b08..876db7138 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -58,7 +58,7 @@ def tokenize(sample): # * apply spinquant transforms to model in order to make quantization easier # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - SpinQuantModifier(rotations=["R1", "R2"], transform_type="random-hadamard"), + SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/llmcompressor/modifiers/transform/quip/template.py b/src/llmcompressor/modifiers/transform/quip/template.py deleted file mode 100644 index 4ce5e47ae..000000000 --- a/src/llmcompressor/modifiers/transform/quip/template.py +++ /dev/null @@ -1,98 +0,0 @@ -from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme - -QUIP = TransformConfig( - config_groups={ - "v": TransformScheme( - type="random-hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="input", # non-mergable - ignore="lm_head", - ), - TransformArgs( - targets=["Linear"], - location="weight_input", - inverse=True, - ignore="lm_head", - ), - ], - randomize=True, - ), - "u": TransformScheme( - type="random-hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="weight_output", - ignore="lm_head", - ), - TransformArgs( - targets=["Linear"], - location="output", # non-mergable - inverse=True, - ignore="lm_head", - ), - ], - randomize=True, - ), - } -) - -# https://github.com/vllm-project/llm-compressor/blob/b43b27a2f277a5e62be4f8c713b84fd1c7aa116b/weight_transform.py#L24-L105 -QUIP_ONLINE = TransformConfig( - config_groups={ - "u_transform_q_o_down_proj": TransformScheme( - 
type="hadamard", - apply=[ - TransformArgs( - targets=[ - "re:.*.attn.q_proj$", - "re:.*.attn.o_proj$", - "re:.*.mlp.down_proj$", - ], - location="weight_input", - ) - ], - ), - "u_transform_k_v_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"], - location="weight_input", - ) - ], - ), - "u_transform_gate_up_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"], - location="weight_input", - ) - ], - ), - "v_transform_linear": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="weight_output", - ignore=["re:.*.mlp.down_proj$", "lm_head"], - inverse=True, - ) - ], - ), - "v_transform_down_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.mlp.down_proj$"], - location="weight_output", - inverse=True, - ) - ], - ), - } -) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index e448bd372..5997fac19 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -19,15 +19,15 @@ from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping -class SpinquantRotation(Enum): +class SpinquantRotation(str, Enum): R1 = "R1" R2 = "R2" R3 = "R3" R4 = "R4" -class SpinQuantModifier(Modifier): - rotations: Iterable[SpinquantRotation] = ("R1", "R2") +class SpinQuantModifier(Modifier, use_enum_values=True): + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="hadamard" ) @@ -38,11 +38,12 @@ class SpinQuantModifier(Modifier): # override spinquant mappings with transform_config without overriding norms # we can combine these mappings, but it requires some more validation logic # maybe there's a reason to keep if other modifiers want norm fusing, idk - mappings: Optional[SpinQuantMappings] = None - norm_mappings: Optional[List[NormMapping]] = None + mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True) + norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) # optional override for more fine-grained control - transform_config: Optional[TransformConfig] = None + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None) @field_validator("rotations", mode="before") def validate_rotations(cls, value): diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 2c1a54cf5..98fb836b0 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -8,6 +8,7 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -60,5 +61,9 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: config = modifiers[0].resolve_quantization_config() if not config.requires_calibration_data(): return "datafree" + + # TODO: Remove hardcode + if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): + return "datafree" return "sequential" From 4b4257fe871df0f10b13e8ab9ee16f058a8456ed Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 16:50:10 
-0400 Subject: [PATCH 25/35] add compress model script Signed-off-by: Kyle Sayers --- compress_model.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 compress_model.py diff --git a/compress_model.py b/compress_model.py new file mode 100644 index 000000000..fa67bead0 --- /dev/null +++ b/compress_model.py @@ -0,0 +1,60 @@ +# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.utils import dispatch_for_generation + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, help="Model stub to compress") + parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier") + parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Select model and load it. + MODEL_ID = args.model_id + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + # Select number of samples. 512 samples is a good place to start. + # Increasing the number of samples can improve accuracy. + NUM_CALIBRATION_SAMPLES = 512 + MAX_SEQUENCE_LENGTH = 2048 + + # Configure the quantization algorithm to run. + recipe = [] + if args.transform_type: + recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type)) + + if args.scheme: + recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"])) + + # Apply algorithms. + oneshot( + model=model, + recipe=recipe, + dataset="ultrachat_200k", + splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"}, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + ) + + # Confirm generations of the quantized model look sane. + print("\n\n") + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=100) + print(tokenizer.decode(output[0])) + print("==========================================\n\n") + + # Save to disk compressed. 
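For context on the "make modifier serializable" change a couple of patches above: declaring the rotation enum as a str subclass and enabling use_enum_values means the recipe dumps to plain strings without custom encoders. A minimal pydantic v2 sketch (hypothetical RecipeSketch model and field names, not the actual modifier):

from enum import Enum
from typing import List

from pydantic import BaseModel, Field

class Rotation(str, Enum):
    R1 = "R1"
    R2 = "R2"

class RecipeSketch(BaseModel, use_enum_values=True):
    rotations: List[Rotation] = Field(default_factory=lambda: ["R1", "R2"])
    transform_type: str = "hadamard"

sketch = RecipeSketch(rotations=[Rotation.R1, Rotation.R2])
print(sketch.model_dump())  # {'rotations': ['R1', 'R2'], 'transform_type': 'hadamard'}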
+ SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}" + model.save_pretrained(SAVE_DIR, save_compressed=True) + tokenizer.save_pretrained(SAVE_DIR) From dc7ac1a1e4a94c8402f003d90eaa5a75dccabb21 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 11:08:09 -0400 Subject: [PATCH 26/35] use untie_word_embeddings Signed-off-by: Kyle Sayers --- src/llmcompressor/entrypoints/utils.py | 7 ++- .../compressed_tensors_utils.py | 50 +++++++++---------- .../test_compress_tensor_utils.py | 42 ++++++---------- 3 files changed, 43 insertions(+), 56 deletions(-) diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py index 5647e4d06..95ec832fb 100644 --- a/src/llmcompressor/entrypoints/utils.py +++ b/src/llmcompressor/entrypoints/utils.py @@ -20,7 +20,7 @@ from llmcompressor.pytorch.model_load.helpers import parse_dtype from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from llmcompressor.transformers.utils.helpers import ( detect_last_checkpoint, @@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"): ) # untie tie_word_embeddings weights - patch_tied_tensors_bug(model_args.model) + if not model_args.tie_word_embeddings: + untie_word_embeddings(model_args.model) # wrap model.save_pretrained modify_save_pretrained(model_args.model) @@ -143,7 +144,6 @@ def initialize_model_from_path( cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) @@ -156,7 +156,6 @@ def initialize_model_from_path( AutoConfig.from_pretrained( model_args.distill_teacher, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) if model_args.distill_teacher diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 69b0e3f28..0fdaa9dc6 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -9,8 +9,9 @@ CompressionFormat, ModelCompressor, SparsityCompressionConfig, + delete_offload_parameter, is_module_offloaded, - update_offload_parameter, + register_offload_parameter, ) from loguru import logger from safetensors.torch import storage_ptr @@ -27,7 +28,7 @@ from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path -__all__ = ["modify_save_pretrained"] +__all__ = ["modify_save_pretrained", "untie_word_embeddings"] def modify_save_pretrained(model: PreTrainedModel): @@ -120,7 +121,7 @@ def save_pretrained_wrapper( model.save_pretrained = save_pretrained_compressed(model.save_pretrained) -def patch_tied_tensors_bug(model: torch.nn.Module): +def untie_word_embeddings(model: PreTrainedModel): """ Patches bug where HF transformers will fail to untie weights under specific circumstances (https://github.com/huggingface/transformers/issues/33689). 
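The untie_word_embeddings rework in this patch boils down to giving the input and output embeddings separate storage, so later weight transforms can modify one without silently modifying the other. A plain-torch sketch of the same idea, not part of the patch (it ignores the offloading-aware register/delete parameter utilities used in the real function):

import torch

emb = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = emb.weight                                   # tied: same storage
print(lm_head.weight.data_ptr() == emb.weight.data_ptr())     # True

# untie by cloning into a fresh Parameter with its own storage
lm_head.weight = torch.nn.Parameter(emb.weight.detach().clone())
print(lm_head.weight.data_ptr() == emb.weight.data_ptr())     # False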
@@ -129,28 +130,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module): :param model: model to fix """ - if ( - hasattr(model.config, "tie_word_embeddings") - and not model.config.tie_word_embeddings - ): - input_embed = model.get_input_embeddings() - output_embed = model.get_output_embeddings() - - if input_embed is None or output_embed is None: - # some models fail to properly override the abstract methods - return - - if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight): - for module in (input_embed, output_embed): - if not is_module_offloaded(module): - # create new storage ptr for onloaded weight - untied_data = module.weight.data.clone() - module.weight.data = untied_data - else: - # create new storage ptr for offloaded weight - # note `update_offload_parameter` does not create a new storage ptr - untied_data = module._hf_hook.weights_map["weight"].clone() - update_offload_parameter(module, "weight", untied_data) + input_embed = model.get_input_embeddings() + output_embed = model.get_output_embeddings() + + for module in (input_embed, output_embed): + if module is None or not hasattr(module, "weight"): + logger.warning(f"Cannot untie {module} which does not have weight param") + continue + + # this could be replaced by a `get_offloaded_parameter` util + if not is_module_offloaded(module): + untied_data = module.weight.data.clone() + else: + untied_data = module._hf_hook.weights_map["weight"].clone() + + requires_grad = module.weight.requires_grad + new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad) + delete_offload_parameter(module, "weight") + register_offload_parameter(module, "weight", new_parameter) + + if hasattr(model.config, "tie_word_embeddings"): + model.config.tie_word_embeddings = False def get_model_compressor( diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index 140e706d1..aad551ff8 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -28,7 +28,7 @@ from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( get_model_compressor, modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from tests.testing_utils import requires_gpu @@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path): shutil.rmtree(tmp_path) -# technically only tie_word_embeddings=False is supported right now -# setting to True is discouraged @pytest.mark.parametrize( "offload,torch_dtype,tie_word_embeddings,device", [ @@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path): # offloading (True, torch.float16, False, "cpu"), (True, torch.float32, False, "cpu"), - # (True, torch.float16, True, "cpu"), # TODO: fails - # (True, torch.float32, True, "cpu"), # TODO: fails + (True, torch.float16, True, "cpu"), + (True, torch.float32, True, "cpu"), ], ) def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path): model_path = "nm-testing/llama2.c-stories15M" save_path = tmp_path / "save_path" - model = AutoModelForCausalLM.from_pretrained( - model_path, - tie_word_embeddings=tie_word_embeddings, - torch_dtype=torch_dtype, - ) + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) - 
patch_tied_tensors_bug(model) + if not tie_word_embeddings: + untie_word_embeddings(model) + modify_save_pretrained(model) model.save_pretrained(save_path, safe_serialization=True) @@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp (True, torch.float32, True, "cpu"), ], ) -def test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): +def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device): # load model - model = AutoModelForCausalLM.from_pretrained( - "nm-testing/llama2.c-stories15M", - torch_dtype=torch_dtype, - tie_word_embeddings=tie_word_embeddings, - ) - patch_tied_tensors_bug(model) - + model_path = "nm-testing/llama2.c-stories15M" + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) + if not tie_word_embeddings: + untie_word_embeddings(model) + # modify lm head with torch.no_grad(), align_module_device(model.lm_head): update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) @@ -332,12 +324,8 @@ def test_model_shared_tensors( (False, torch.float32, True, "cuda:0"), ], ) -def test_model_shared_tensors_gpu( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): - test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path - ) +def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device): + test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device) @requires_gpu From 8542f8d1ea21f78338f7b9ca6e1df5b49c9d8232 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 11:08:44 -0400 Subject: [PATCH 27/35] style Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 98fb836b0..67d510d13 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -61,7 +61,7 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: config = modifiers[0].resolve_quantization_config() if not config.requires_calibration_data(): return "datafree" - + # TODO: Remove hardcode if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): return "datafree" From b1e637eb88f0b9d8c5524a836d99c0baade0a54f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 12:46:13 -0400 Subject: [PATCH 28/35] better registery logic Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 12 +++++------ .../modifiers/transform/spinquant/mappings.py | 21 ++++++++++++++++--- .../transform/spinquant/norm_mappings.py | 19 +++++++++++++++-- .../compressed_tensors_utils.py | 1 - 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 5997fac19..c8376a6a0 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -15,8 +15,8 @@ from llmcompressor.modeling import fuse_norm_linears, normalize_embedding from llmcompressor.modifiers import Modifier -from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings -from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping +from .mappings import SpinQuantMapping, infer_mapping_from_model +from 
.norm_mappings import NormMapping, infer_norm_mapping_from_model class SpinquantRotation(str, Enum): @@ -36,9 +36,7 @@ class SpinQuantModifier(Modifier, use_enum_values=True): # norm mappings separate from spinquant mappings to allow users to # override spinquant mappings with transform_config without overriding norms - # we can combine these mappings, but it requires some more validation logic - # maybe there's a reason to keep if other modifiers want norm fusing, idk - mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True) + mappings: Optional[SpinQuantMapping] = Field(default=None, exclude=True) norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) # optional override for more fine-grained control @@ -53,8 +51,8 @@ def validate_rotations(cls, value): def on_initialize(self, state: State, **kwargs) -> bool: # TODO: more validation - self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__] - self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__] + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) if self.transform_config is not None: if self.mappings is not None: diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py index acf692d22..7dc327b78 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -1,9 +1,13 @@ from typing import Dict, List, Optional +from loguru import logger from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel +__all__ = ["SpinQuantMapping", "infer_mapping_from_model"] -class SpinQuantMappings(BaseModel): + +class SpinQuantMapping(BaseModel): embedding: str attn_q: str @@ -25,7 +29,7 @@ def cast_to_list(cls, value): return value -_default_mappings = SpinQuantMappings( +_default_mappings = SpinQuantMapping( embedding="re:.*embed_tokens$", attn_q="re:.*q_proj$", attn_k="re:.*k_proj$", @@ -37,6 +41,17 @@ def cast_to_list(cls, value): ) -SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = { +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { "LlamaForCausalLM": _default_mappings, } + + +def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping: + architecture = model.__class__.__name__ + if architecture not in SPINQUANT_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. 
" + "Falling back to default mappings" + ) + + return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py index cefb987ca..0752f6986 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -1,6 +1,10 @@ from typing import Dict, List +from loguru import logger from pydantic import BaseModel, field_validator +from transformers import PreTrainedModel + +__all__ = ["infer_norm_mapping_from_model"] class NormMapping(BaseModel): @@ -15,7 +19,7 @@ def cast_to_list(cls, value): return value -_default_norm_mappings = [ +_default_mappings = [ NormMapping( norm="re:.*input_layernorm$", linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], @@ -31,5 +35,16 @@ def cast_to_list(cls, value): ] NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { - "LlamaForCausalLM": _default_norm_mappings, + "LlamaForCausalLM": _default_mappings, } + + +def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: + architecture = model.__class__.__name__ + if architecture not in NORM_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 0fdaa9dc6..1495f6d06 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -14,7 +14,6 @@ register_offload_parameter, ) from loguru import logger -from safetensors.torch import storage_ptr from transformers import PreTrainedModel from llmcompressor.core import active_session From b44ac817b65dec264146c849d67566de5b38cc37 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 13:05:05 -0400 Subject: [PATCH 29/35] remove dummy model test (add later) Signed-off-by: Kyle Sayers --- .../modifiers/transform/test_dummy_model.py | 107 ------------------ 1 file changed, 107 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/transform/test_dummy_model.py diff --git a/tests/llmcompressor/modifiers/transform/test_dummy_model.py b/tests/llmcompressor/modifiers/transform/test_dummy_model.py deleted file mode 100644 index 020a61e99..000000000 --- a/tests/llmcompressor/modifiers/transform/test_dummy_model.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from compressed_tensors.utils import update_parameter_data -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.models.llama.modeling_llama import LlamaRMSNorm - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import SpinQuantModifier -from llmcompressor.utils import dispatch_for_generation - -hidden_dim = intermediate_dim = 64 -up_dim = 128 -num_embeddings = 12 - - -class DummySelfAttn(torch.nn.Module): - def __init__(self, hidden_dim, intermediate_dim): - super().__init__() - self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) - self.k_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) - self.v_proj = torch.nn.Linear(hidden_dim, intermediate_dim, 
bias=None) - self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) - self.num_heads = 1 - self.num_key_value_groups = 1 - - def forward(self, hidden_states): - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - ### EAGER ATTENTION - attn_weights = torch.matmul(q.T, k) - - attn_weights = torch.nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(q.dtype) - attn_output = torch.matmul(attn_weights, v.T) - attn_output = attn_output.T.contiguous() - - return self.o_proj(attn_output) - - -class DummyMLP(torch.nn.Module): - def __init__(self, hidden_dim, up_dim): - super().__init__() - self.up_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) - self.gate_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) - self.down_proj = torch.nn.Linear(up_dim, hidden_dim, bias=None) - self.act_fn = torch.nn.SiLU() - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class DummyModel(torch.nn.Module): - def __init__(self, num_embeddings, hidden_dim, intermediate_dim, up_dim): - super().__init__() - self.embed_tokens = torch.nn.Embedding(num_embeddings, hidden_dim) - self.input_layernorm = LlamaRMSNorm(hidden_dim) - self.post_attention_layernorm = LlamaRMSNorm(hidden_dim) - self.self_attn = DummySelfAttn(hidden_dim, intermediate_dim) - self.mlp = DummyMLP(hidden_dim, up_dim) - self.lm_head = torch.nn.Linear(hidden_dim, num_embeddings, bias=None) - - def forward(self, input_ids): - x = self.embed_tokens(input_ids) - x = self.input_layernorm(x) - x = self.self_attn(x) - x = self.post_attention_layernorm(x) - x = self.mlp(x) - return self.lm_head(x) - - -def test_dummy_model(): - model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) - - # TODO Uncomment this to see norm diff > 1e-6 - # This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 - # Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) - # https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 - # update_parameter_data( - # model.input_layernorm, - # torch.rand(model.input_layernorm.weight.shape), - # "weight", - # ) - - input_ids = torch.IntTensor([1, 2, 3, 4, 5]) - orig_output = model(input_ids) - - recipe = [ - SpinQuantModifier(rotations=["R1", "R2"]), - ] - - # TODO: work around preprocessing? 
-    oneshot(
-        model=model,
-        recipe=recipe,
-        pipeline="datafree",
-        log_dir=None,
-    )
-
-    # # Confirm generations of the quantized model look the same
-    transformed_output = model(input_ids)
-
-    print(f"Norm Diff {(orig_output-transformed_output).norm()}")
-    print(f"Norm {orig_output.norm()}, {transformed_output.norm()}")

From 7a52b710b73682119c45f61669f92b5ac6e0b189 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 15 Jul 2025 13:34:18 -0400
Subject: [PATCH 30/35] docstring

Signed-off-by: Kyle Sayers
---
 .../modifiers/transform/spinquant/base.py     | 49 ++++++++++++++++---
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index c8376a6a0..5a1ea7844 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -8,7 +8,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from pydantic import Field, field_validator
+from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
@@ -27,6 +27,37 @@ class SpinquantRotation(str, Enum):
 
 
 class SpinQuantModifier(Modifier, use_enum_values=True):
+    """
+    Implements the transforms according to
+    [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406)  # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    The SpinQuant authors describe four different rotations which can be applied to a
+    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
+    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
+    rotations, meaning that they require additional computation at runtime.
+
+    :param rotations: A list containing the names of rotations to apply to the model.
+        Possible rotations include R1, R2, R3, and R4
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a much larger set
+        of sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size
+    :param randomize: if True, create distinct transforms for each application
+    :param learnable: if True, attach gradients to transform weights for training
+    :param mappings: Specifies layers within a model to target for transforms.
+        A mapping will be inferred if None is provided
+    :param norm_mappings: Specifies layers within a model to target for norm fusing.
+ A mapping will be inferred if None is provided + :param transform_config: Optional transform config which overrides `mappings` + """ + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="hadamard" @@ -43,6 +74,10 @@ class SpinQuantModifier(Modifier, use_enum_values=True): # also included in recipe serialization transform_config: Optional[TransformConfig] = Field(default=None) + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + raise NotImplementedError(f"{info.field_name} is not supported right now") + @field_validator("rotations", mode="before") def validate_rotations(cls, value): if isinstance(value, Iterable): @@ -50,16 +85,18 @@ def validate_rotations(cls, value): return value def on_initialize(self, state: State, **kwargs) -> bool: - # TODO: more validation - self.mappings = infer_mapping_from_model(state.model) - self.norm_mappings = infer_norm_mapping_from_model(state.model) - if self.transform_config is not None: if self.mappings is not None: - raise ValueError() + raise ValueError( + "Please provide either `transform_config` or `mappings` " + "but not both" + ) return True + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) + config_groups = {} if SpinquantRotation.R1 in self.rotations: config_groups["R1"] = self._create_r1_scheme() From f4d7ec6d807c629a264cc90b3fec13d1b281e242 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 15:02:11 -0400 Subject: [PATCH 31/35] update docstring Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 5a1ea7844..2bf593635 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -55,24 +55,34 @@ class SpinQuantModifier(Modifier, use_enum_values=True): A mapping will be inferred if None is provided :param norm_mappings: Specifies layers within a model to target for norm fusing. 
A mapping will be inferred if None is provided - :param transform_config: Optional transform config which overrides `mappings` + :param transform_config: Optional transform config for overriding provided arguments """ - rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + rotations: List[SpinquantRotation] = Field( + default_factory=lambda: ["R1", "R2"], exclude=True + ) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( - default="hadamard" + default="hadamard", exclude=True ) - randomize: bool = Field(default=False) - learnable: bool = Field(default=False) + randomize: bool = Field(default=False, exclude=True) + learnable: bool = Field(default=False, exclude=True) # norm mappings separate from spinquant mappings to allow users to # override spinquant mappings with transform_config without overriding norms - mappings: Optional[SpinQuantMapping] = Field(default=None, exclude=True) - norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) + mappings: Optional[SpinQuantMapping] = Field( + default=None, + repr=False, + exclude=True, + ) + norm_mappings: Optional[List[NormMapping]] = Field( + default=None, + repr=False, + exclude=True, + ) # optional override for more fine-grained control # also included in recipe serialization - transform_config: Optional[TransformConfig] = Field(default=None) + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) @field_validator("randomize", "learnable", mode="before") def validate_not_implemented(cls, value, info: ValidationInfo): @@ -86,12 +96,6 @@ def validate_rotations(cls, value): def on_initialize(self, state: State, **kwargs) -> bool: if self.transform_config is not None: - if self.mappings is not None: - raise ValueError( - "Please provide either `transform_config` or `mappings` " - "but not both" - ) - return True self.mappings = infer_mapping_from_model(state.model) From f18d0e894d984d6ec9207f9fe71e6533669c8aa3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 15:07:27 -0400 Subject: [PATCH 32/35] rename example file Signed-off-by: Kyle Sayers --- examples/transform/{llama3_example.py => spinquant_example.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/transform/{llama3_example.py => spinquant_example.py} (100%) diff --git a/examples/transform/llama3_example.py b/examples/transform/spinquant_example.py similarity index 100% rename from examples/transform/llama3_example.py rename to examples/transform/spinquant_example.py From cec2914342ad337be13fccff29ca7426d713c0ec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Jul 2025 11:13:48 -0400 Subject: [PATCH 33/35] use match_modules_set Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 2bf593635..5978b93ea 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Iterable, List, Literal, Optional -from compressed_tensors import is_match, match_named_modules +from compressed_tensors import match_modules_set, match_named_modules from compressed_tensors.transform import ( TransformArgs, TransformConfig, @@ -156,24 +156,10 @@ def _prenormalize_embeddings(self, model: PreTrainedModel): def 
_fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: - targets = (mapping.norm, *mapping.linears) - matches = dict() - - for name, module in model.named_modules(): - # match until we get a full set - for target in targets: - if is_match(name, module, target): - if target in matches: - raise ValueError("Cannot match twice") - matches[target] = module - - # once we have a full set, fuse and reset - if all(target in matches for target in targets): - fuse_norm_linears( - matches[mapping.norm], - (matches[target] for target in mapping.linears), - ) - matches = dict() + for norm, *linears in match_modules_set( + model, (mapping.norm, *mapping.linears) + ): + fuse_norm_linears(norm, linears) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( From 0c5c514313d887caf715aa9f14bdb35f50e3bad6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 17 Jul 2025 14:12:55 -0400 Subject: [PATCH 34/35] unit test fixes Signed-off-by: Brian Dellabetta --- src/llmcompressor/modeling/fuse.py | 3 ++- tests/llmcompressor/modeling/test_fuse.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 40dc31e6a..e59be596c 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -32,6 +32,7 @@ def normalize_embedding(embedding: torch.nn.Module): else: raise ValueError(f"Cannot normalize embedding of type {type(embedding)}") + def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): """ Fuse a norm layer into subsequent linear layers. This useful for ensuring transform @@ -42,7 +43,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)): + if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)): for linear in linears: # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py index 005d89f99..f85cd68dc 100644 --- a/tests/llmcompressor/modeling/test_fuse.py +++ b/tests/llmcompressor/modeling/test_fuse.py @@ -1,13 +1,13 @@ import pytest import torch -from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears +from llmcompressor.modeling.fuse import normalize_embedding, fuse_norm_linears @pytest.mark.unit -def test_center_embeddings(): +def test_normalize_embedding(): embedding = torch.nn.Embedding(10, 10) - center_embeddings(embedding) + normalize_embedding(embedding) assert torch.allclose( embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5 From f2ef7cfd5734434b285c44ac43c8f108ba9afae1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 17 Jul 2025 14:16:26 -0400 Subject: [PATCH 35/35] style fixes Signed-off-by: Brian Dellabetta --- tests/llmcompressor/modeling/test_fuse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py index f85cd68dc..5798f692c 100644 --- a/tests/llmcompressor/modeling/test_fuse.py +++ b/tests/llmcompressor/modeling/test_fuse.py @@ -1,7 +1,7 @@ import pytest import torch -from llmcompressor.modeling.fuse import normalize_embedding, fuse_norm_linears +from llmcompressor.modeling.fuse 
import fuse_norm_linears, normalize_embedding @pytest.mark.unit
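
Taken together, this series renames the example to `spinquant_example.py`, routes a lone `SpinQuantModifier` through the data-free pipeline, and fixes the norm/embedding fusing utilities it relies on. Below is a minimal editorial sketch (not part of the patches) of how the modifier could be exercised end to end, assuming the API shown in the diffs above; the tiny `nm-testing/llama2.c-stories15M` checkpoint is borrowed from the tests in this series, and only the offline rotations R1/R2 are applied.

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import SpinQuantModifier

# small test model referenced in this series; any supported causal LM should work
MODEL_ID = "nm-testing/llama2.c-stories15M"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# R1/R2 are offline rotations fused into the weights, so no calibration data is
# required and the data-free pipeline from the registry hunk in PATCH 27 applies
recipe = [SpinQuantModifier(rotations=["R1", "R2"])]
oneshot(model=model, recipe=recipe, pipeline="datafree")

SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)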