It contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery finetuning,
and quantize to 4 bits in one shot using GPTQ.

``` python
from pathlib import Path

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot, train

# load the model in as bfloat16 to save on memory and compute
model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_stub)

# uses LLM Compressor's built-in preprocessing for ultrachat
dataset = "ultrachat-200k"

# select the recipe for 2:4 sparsity and 4-bit weight quantization
recipe = "2of4_w4a16_recipe.yaml"

# save location of quantized model
output_dir = "output_llama7b_2of4_w4a16_channel"
output_path = Path(output_dir)

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512

# set training parameters for finetuning
num_train_epochs = 0.01
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True  # saves memory during training
learning_rate = 0.0001
bf16 = False  # using full precision for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
preprocessing_num_workers = 64

oneshot_kwargs = dict(
    dataset=dataset,
    recipe=recipe,
    num_calibration_samples=num_calibration_samples,
    preprocessing_num_workers=preprocessing_num_workers,
    splits=splits,
)

training_kwargs = dict(
    bf16=bf16,
    max_seq_length=max_seq_length,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
)
```
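
The `oneshot_kwargs` and `training_kwargs` dictionaries simply bundle the arguments used by the compression calls in Step 2. As a minimal sketch of an alternative calling style (Step 2 below passes the same arguments explicitly instead), they can be unpacked straight into `oneshot` and `train`:

``` python
# hypothetical alternative to the explicit calls in Step 2: unpack the
# kwarg dicts defined above instead of repeating each argument
oneshot(model=model, **oneshot_kwargs, output_dir=output_dir, stage="sparsity_stage")
train(
    model=output_path / "sparsity_stage",
    **oneshot_kwargs,
    **training_kwargs,
    output_dir=output_dir,
    stage="finetuning_stage",
)
```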

## Step 2: Run sparsification, fine-tuning, and quantization

The compression process runs in three stages: sparsification, fine-tuning, and quantization.
Each stage saves its intermediate model output to the `output_llama7b_2of4_w4a16_channel` directory.

``` python
from pathlib import Path

from llmcompressor import oneshot, train

output_dir = "output_llama7b_2of4_w4a16_channel"
output_path = Path(output_dir)

# 1. Oneshot sparsification: apply 2:4 pruning
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=512,
    preprocessing_num_workers=8,
    output_dir=output_dir,
    stage="sparsity_stage",
)

# 2. Sparse fine-tuning: recover accuracy of the pruned model
train(
    model=output_path / "sparsity_stage",
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=512,
    preprocessing_num_workers=8,
    bf16=False,
    max_seq_length=512,
    num_train_epochs=0.5,
    logging_steps=500,
    save_steps=5000,
    gradient_checkpointing=True,
    learning_rate=0.0001,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    output_dir=output_dir,
    stage="finetuning_stage",
)

# 3. Oneshot quantization: compress model weights to 4-bit precision
quantized_model = oneshot(
    model=output_path / "finetuning_stage",
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=512,
    preprocessing_num_workers=8,
    output_dir=output_dir,
    stage="quantization_stage",
)

# save the compressed model and tokenizer to the final stage directory
quantized_model.save_pretrained(
    f"{output_dir}/quantization_stage", skip_sparsity_compression_stats=False
)
tokenizer.save_pretrained(f"{output_dir}/quantization_stage")
```
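
After the quantization stage completes, it can be useful to reload the saved checkpoint and confirm it still generates sensible text. The snippet below is a hedged sketch: it assumes the checkpoint written to `output_llama7b_2of4_w4a16_channel/quantization_stage` can be loaded back through the standard `transformers` API with the `compressed-tensors` package installed, and the prompt string is only an illustration.

``` python
from transformers import AutoModelForCausalLM, AutoTokenizer

# assumption: the final stage directory written above; adjust if output_dir changed
checkpoint = "output_llama7b_2of4_w4a16_channel/quantization_stage"

reloaded_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
reloaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = reloaded_tokenizer("Sparse and quantized models are", return_tensors="pt").to(
    reloaded_model.device
)
outputs = reloaded_model.generate(**inputs, max_new_tokens=32)
print(reloaded_tokenizer.decode(outputs[0], skip_special_tokens=True))
```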

### Custom Quantization

The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group`, and `channel`.

The recipe (`2of4_w4a16_recipe.yaml`) uses channel-wise quantization (`strategy: "channel"`).
To change the quantization strategy, edit the recipe file accordingly:

- Use `tensor` for per-tensor quantization.
- Use `group` for group-wise quantization and specify the `group_size` parameter (e.g., 128), as sketched below.
- See `2of4_w4a16_group-128_recipe.yaml` for a complete group-size example.
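
For orientation, the quantization portion of a group-wise recipe might look roughly like the sketch below. Field names follow the example recipes shipped with LLM Compressor, but treat this as an illustration and defer to `2of4_w4a16_group-128_recipe.yaml` for the authoritative configuration; the stage and group names here are placeholders.

``` yaml
# illustrative only -- see 2of4_w4a16_group-128_recipe.yaml for the real recipe
quantization_stage:
  quant_modifiers:
    GPTQModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "group"
            group_size: 128
          targets: ["Linear"]
```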