Skip to content

[Tests] Spinquant dummy model tests #1647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ba617db
wip
kylesayrs Jun 6, 2025
2f5b1c8
use random-hadamard, add correctness tests
kylesayrs Jun 12, 2025
3aa35e7
add correctness test, note that precision makes a large difference
kylesayrs Jun 12, 2025
b6c088e
add on lifecycle methods
brian-dellabetta Jun 23, 2025
d1eb2a1
Merge branch 'main' into kylesayrs/transform-modifier
brian-dellabetta Jul 1, 2025
3207124
TransformModifier with SpinQuant R1&R2
brian-dellabetta Jul 2, 2025
a88ca3c
spinquant and quip_online, running but outputting gibberish
brian-dellabetta Jul 2, 2025
5bd51df
updated example
brian-dellabetta Jul 2, 2025
3c216dd
DummyModel script
brian-dellabetta Jul 8, 2025
bbcdc8c
implement fuse_norm_linears
kylesayrs Jul 10, 2025
bd7f4d5
Merge branch 'kylesayrs/fuse-helpers' into bdellabe/transform-modifier
kylesayrs Jul 10, 2025
f5c2150
R1 working
kylesayrs Jul 11, 2025
dc5c30c
add r2, increase precision
kylesayrs Jul 11, 2025
7172c26
spinquant modifier
kylesayrs Jul 11, 2025
9298e82
remove space
kylesayrs Jul 11, 2025
f77226d
use iterable
kylesayrs Jul 11, 2025
fdb64b5
add rotation validation
kylesayrs Jul 11, 2025
5daa2d5
embedding fusion
kylesayrs Jul 11, 2025
0e9af7b
add missing norm fusion
kylesayrs Jul 12, 2025
fce83be
use norm mappings
kylesayrs Jul 12, 2025
a979f8a
break into separate files
kylesayrs Jul 12, 2025
4cab29e
small cleanup
kylesayrs Jul 12, 2025
f1cc987
cleanup
kylesayrs Jul 14, 2025
a7bb2e2
more cleanup
kylesayrs Jul 14, 2025
0cf0188
make new weight on cpu
kylesayrs Jul 14, 2025
53ea307
standardize, make modifier serializable
kylesayrs Jul 14, 2025
4b4257f
add compress model script
kylesayrs Jul 14, 2025
dc7ac1a
use untie_word_embeddings
kylesayrs Jul 15, 2025
8542f8d
style
kylesayrs Jul 15, 2025
b1e637e
better registery logic
kylesayrs Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions compress_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, help="Model stub to compress")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The --model_id argument is essential for this script to run. Consider making it a required argument to provide a clearer usage error to the user if it's not provided.

Suggested change
parser.add_argument("--model_id", type=str, help="Model stub to compress")
parser.add_argument("--model_id", type=str, required=True, help="Model stub to compress")

parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier")
parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)")
return parser.parse_args()

if __name__ == "__main__":
args = parse_args()

# Select model and load it.
MODEL_ID = args.model_id
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
recipe = []
if args.transform_type:
recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type))

if args.scheme:
recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"]))

# Apply algorithms.
oneshot(
model=model,
recipe=recipe,
dataset="ultrachat_200k",
splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
86 changes: 86 additions & 0 deletions examples/transform/llama3_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * apply spinquant transforms to model in order to make quantization easier
# * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = [
SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
model=model,
recipe=recipe,
dataset=ds,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
7 changes: 3 additions & 4 deletions src/llmcompressor/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from llmcompressor.pytorch.model_load.helpers import parse_dtype
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
untie_word_embeddings,
)
from llmcompressor.transformers.utils.helpers import (
detect_last_checkpoint,
Expand Down Expand Up @@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"):
)

# untie tie_word_embeddings weights
patch_tied_tensors_bug(model_args.model)
if not model_args.tie_word_embeddings:
untie_word_embeddings(model_args.model)

# wrap model.save_pretrained
modify_save_pretrained(model_args.model)
Expand Down Expand Up @@ -143,7 +144,6 @@ def initialize_model_from_path(
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)

Expand All @@ -156,7 +156,6 @@ def initialize_model_from_path(
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
if model_args.distill_teacher
Expand Down
1 change: 1 addition & 0 deletions src/llmcompressor/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa

from .fuse import *
from .prepare import *
58 changes: 58 additions & 0 deletions src/llmcompressor/modeling/fuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Iterable

import torch
from compressed_tensors import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm

__all__ = ["normalize_embedding", "fuse_norm_linears"]


PRECISION = torch.float64


def normalize_embedding(embedding: torch.nn.Module):
if isinstance(embedding, (torch.nn.Embedding)):
with align_module_device(embedding):
weight_dtype = embedding.weight.dtype
weight = embedding.weight.to(PRECISION)
new_weight = weight - weight.mean(dim=-1, keepdim=True)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(embedding, "weight", new_weight)

else:
raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")


def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
"""
Fuse a norm layer into subsequent linear layers. This useful for ensuring transform
invariance between norm and linear layers.

Note that a model cannot be properly trained after its norms have been fused

:param norm: norm layer whose weight will be fused into subsequent linears
:param linears: linear layers which directly follow the norm layer
"""
if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
for linear in linears:
# NOTE: spinquant does this op in float64
exec_device = get_execution_device(norm)
with align_module_device(norm, exec_device), align_module_device(
linear, exec_device
):
weight_dtype = linear.weight.dtype
new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(linear, "weight", new_weight)

new_norm_weight = torch.ones_like(norm.weight, device="cpu")
update_offload_parameter(norm, "weight", new_norm_weight)

else:
raise ValueError(f"Cannot fuse norm of type {type(norm)}")
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .spinquant import SpinQuantModifier
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
Loading
Loading