Skip to content

Commit ba617db

Browse files
kylesayrsbrian-dellabetta
authored andcommitted
wip
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 50bb656 commit ba617db

File tree

5 files changed

+221
-0
lines changed

5 files changed

+221
-0
lines changed

examples/transform/llama3_example.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from datasets import load_dataset
2+
from transformers import AutoModelForCausalLM, AutoTokenizer
3+
4+
from llmcompressor.modifiers.quantization import GPTQModifier
5+
from llmcompressor.modifiers.transform import TransformModifier
6+
from llmcompressor.transformers import oneshot
7+
8+
# Select model and load it.
9+
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
10+
11+
model = AutoModelForCausalLM.from_pretrained(
12+
MODEL_ID,
13+
device_map="auto",
14+
torch_dtype="auto",
15+
)
16+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
17+
18+
# Select calibration dataset.
19+
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
20+
DATASET_SPLIT = "train_sft"
21+
22+
# Select number of samples. 512 samples is a good place to start.
23+
# Increasing the number of samples can improve accuracy.
24+
NUM_CALIBRATION_SAMPLES = 512
25+
MAX_SEQUENCE_LENGTH = 2048
26+
27+
# Load dataset and preprocess.
28+
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
29+
ds = ds.shuffle(seed=42)
30+
31+
32+
def preprocess(example):
33+
return {
34+
"text": tokenizer.apply_chat_template(
35+
example["messages"],
36+
tokenize=False,
37+
)
38+
}
39+
40+
41+
ds = ds.map(preprocess)
42+
43+
44+
# Tokenize inputs.
45+
def tokenize(sample):
46+
return tokenizer(
47+
sample["text"],
48+
padding=False,
49+
max_length=MAX_SEQUENCE_LENGTH,
50+
truncation=True,
51+
add_special_tokens=False,
52+
)
53+
54+
55+
ds = ds.map(tokenize, remove_columns=ds.column_names)
56+
57+
# Configure the quantization algorithm to run.
58+
# * quantize the weights to 4 bit with GPTQ with a group size 128
59+
recipe = [
60+
TransformModifier(),
61+
GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
62+
]
63+
64+
# Apply algorithms.
65+
oneshot(
66+
model=model,
67+
dataset=ds,
68+
recipe=recipe,
69+
max_seq_length=MAX_SEQUENCE_LENGTH,
70+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
71+
)
72+
73+
# Confirm generations of the quantized model look sane.
74+
print("\n\n")
75+
print("========== SAMPLE GENERATION ==============")
76+
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
77+
output = model.generate(input_ids, max_new_tokens=100)
78+
print(tokenizer.decode(output[0]))
79+
print("==========================================\n\n")
80+
81+
# Save to disk compressed.
82+
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
83+
model.save_pretrained(SAVE_DIR, save_compressed=True)
84+
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .transform import TransformModifier
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
2+
3+
4+
QUIP = TransformConfig(
5+
config_groups={
6+
"v": TransformScheme(
7+
type="hadamard",
8+
apply=[
9+
TransformArgs(
10+
targets=["Linear"],
11+
location="input", # non-mergable
12+
ignore="lm_head",
13+
),
14+
TransformArgs(
15+
targets=["Linear"],
16+
location="weight_input",
17+
inverse=True,
18+
ignore="lm_head",
19+
),
20+
],
21+
randomize=True,
22+
),
23+
"u": TransformScheme(
24+
type="hadamard",
25+
apply=[
26+
TransformArgs(
27+
targets=["Linear"],
28+
location="weight_output",
29+
ignore="lm_head",
30+
),
31+
TransformArgs(
32+
targets=["Linear"],
33+
location="output", # non-mergable
34+
inverse=True,
35+
ignore="lm_head"
36+
),
37+
],
38+
randomize=True,
39+
),
40+
}
41+
)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
2+
3+
4+
LLAMA_SPINQUANT = TransformConfig(
5+
transform_groups={
6+
"R1": TransformScheme(
7+
type="hadamard",
8+
apply=[
9+
TransformArgs(
10+
targets=["embed_tokens", "o_proj", "down_proj"],
11+
location="weight_output",
12+
),
13+
TransformArgs(
14+
targets=[
15+
"q_proj",
16+
"k_proj",
17+
"v_proj",
18+
"up_proj",
19+
"gate_proj",
20+
"lm_head",
21+
],
22+
location="weight_input",
23+
inverse=True,
24+
),
25+
],
26+
),
27+
"R2": TransformScheme(
28+
type="hadamard",
29+
apply=[
30+
TransformArgs(
31+
targets=["v_proj"],
32+
location="weight_output",
33+
),
34+
TransformArgs(
35+
targets=["o_proj"], location="weight_input", inverse=True
36+
),
37+
],
38+
),
39+
"R3": TransformScheme(
40+
type="hadamard",
41+
apply=[
42+
TransformArgs(
43+
targets=["self_attn"],
44+
location="k_cache",
45+
),
46+
TransformArgs(
47+
targets=["self_attn"],
48+
location="q_attn",
49+
),
50+
],
51+
),
52+
"R4": TransformScheme(
53+
type="hadamard",
54+
apply=[
55+
TransformArgs(
56+
targets=["down_proj"],
57+
location="input",
58+
),
59+
TransformArgs(
60+
targets=["down_proj"], location="weight_input", inverse=True
61+
),
62+
],
63+
),
64+
}
65+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Dict, Optional
2+
3+
from llmcompressor.core import State
4+
from llmcompressor.modifiers import Modifier
5+
6+
from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory
7+
8+
from .template.quip import QUIP
9+
10+
class TransformModifier(Modifier):
11+
preset_config: Optional[str] = None
12+
config_groups: Optional[Dict[str, TransformScheme]] = None
13+
14+
# model validator to validate both preset and config gropus are not provided
15+
16+
def on_initialize(self, state: State, **kwargs):
17+
if self.preset_config is not None:
18+
# import config template and customize to model
19+
pass
20+
21+
22+
#config = TransformConfig(config_groups=self.config_groups)
23+
config = QUIP
24+
25+
# TODO: use CT-provided apply_transform_config
26+
for name, scheme in config.config_groups.items():
27+
factory = TransformFactory.from_scheme(scheme, name=name)
28+
factory.apply_to_model(state.model)

0 commit comments

Comments
 (0)