
Commit 5f24384

brian-dellabetta, kylesayrs, and rahul-tuli authored
Kylesayrs/update readme (#252)
* WIP

  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* Fix errors in scripts and notebooks in `examples/` and drop `sparseml` dependence (#247)

  * first pass, awaiting team feedback
  * drop hf-transfer nonsense
  * remaining example files
  * black/isort
  * Apply suggestions from code review

    Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>

  * notebook example cleanup

    Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>

  * update compressed-tensors examples for QDQ and actual compression using torch hooks

    Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>

  * Update examples/llama_1.1b/ex_config_quantization.py

    Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>

  * f-string typo

    Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>

  * updates from code review

    Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>

  ---------

  Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
  Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
  Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent 22c09f3 commit 5f24384

File tree

9 files changed: +323 additions, -165 deletions


.gitignore

Lines changed: 2 additions & 0 deletions

@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+examples/**/*.safetensors

examples/bit_packing/ex_quantize_and_pack.py

Lines changed: 43 additions & 36 deletions

@@ -12,35 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from tqdm import tqdm
-from torch.utils.data import RandomSampler
+####
+#
+# The following example shows how to run QDQ inside `compressed-tensors`
+# QDQ (quantize & de-quantize) is a way to evaluate quantized model
+# accuracy but will not lead to a runtime speedup.
+# See `../llama_1.1b/ex_config_quantization.py` to go beyond QDQ
+# and quantize models that will run more performantly.
+#
+####
+
+from pathlib import Path
+
+import torch
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.quantization import (
-    apply_quantization_config,
-    freeze_module_quantization,
     QuantizationConfig,
     QuantizationStatus,
+    apply_quantization_config,
 )
-from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
-from sparseml.transformers.finetune.data.base import TextGenerationDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
+from datasets import load_dataset
 from torch.utils.data import DataLoader
-from sparseml.pytorch.utils import tensors_to_device
-import torch
-from compressed_tensors.compressors import ModelCompressor
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer

-config_file = "int4_config.json"
+config_file = Path(__file__).parent / "int4_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-dataset_name = "open_platypus"
+dataset_name = "garage-bAInd/Open-Platypus"
 split = "train"
 num_calibration_samples = 128
 max_seq_length = 512
 pad_to_max_length = False
 output_dir = "./llama1.1b_new_quant_out_test_packing"
 device = "cuda:0" if torch.cuda.is_available() else "cpu"

-model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name, device_map=device, torch_dtype="auto"
+)
 model.eval()  # no grad or updates needed for base model
-config = QuantizationConfig.parse_file(config_file)
+config = QuantizationConfig.model_validate_json(config_file.read_text())

 # set status to calibration
 config.quantization_status = QuantizationStatus.CALIBRATION

@@ -49,39 +60,35 @@
 apply_quantization_config(model, config)

 # create dataset
+dataset = load_dataset(dataset_name, split=f"train[:{num_calibration_samples}]")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-data_args = DataTrainingArguments(
-    dataset=dataset_name,
-    max_seq_length=max_seq_length,
-    pad_to_max_length=pad_to_max_length,
-)
-dataset_manager = TextGenerationDataset.load_from_registry(
-    data_args.dataset,
-    data_args=data_args,
-    split=split,
-    tokenizer=tokenizer,
-)
-calib_dataset = dataset_manager.tokenize_and_process(
-    dataset_manager.get_raw_dataset()
-)
+
+
+def tokenize_function(examples):
+    return tokenizer(
+        examples["output"], padding=False, truncation=True, max_length=1024
+    )
+
+
+tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
 data_loader = DataLoader(
-    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(calib_dataset)
+    tokenized_dataset,
+    batch_size=1,
 )

-# run calibration
 with torch.no_grad():
     for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
-        sample = tensors_to_device(sample, "cuda:0")
+        sample = {k: v.to(model.device) for k, v in sample.items()}
        _ = model(**sample)

        if idx >= num_calibration_samples:
            break

-# freeze params after calibration
-model.apply(freeze_module_quantization)
-
-# apply compression
+# convert model to QDQ model
 compressor = ModelCompressor(quantization_config=config)
 compressed_state_dict = compressor.compress(model)
+
+# save QDQ model
 model.save_pretrained(output_dir, state_dict=compressed_state_dict)
-compressor.update_config(output_dir)
+compressor.update_config(output_dir)
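For readers skimming this diff, the "pack" in `ex_quantize_and_pack.py` refers to storing two 4-bit values per byte. Below is a minimal, self-contained sketch of that idea; `pack_int4` and `unpack_int4` are hypothetical helpers written for illustration, not functions from `compressed-tensors`, whose actual `pack-quantized` layout may differ.

```python
import torch


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # shift int4 values in [-8, 7] to unsigned nibbles in [0, 15],
    # then store two nibbles per uint8 byte (high nibble | low nibble)
    u = (x + 8).to(torch.uint8)
    return (u[..., ::2] << 4) | u[..., 1::2]


def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # split each byte back into its two nibbles and undo the +8 shift
    hi = (packed >> 4).to(torch.int8) - 8
    lo = (packed & 0x0F).to(torch.int8) - 8
    return torch.stack([hi, lo], dim=-1).flatten(-2)


w = torch.randint(-8, 8, (4, 8))  # fake int4 weights; last dim must be even
assert torch.equal(unpack_int4(pack_int4(w)).to(w.dtype), w)
```

Packing halves storage relative to int8, but the values must be unpacked (or consumed by fused kernels) at runtime, which is why the QDQ path in the script above is positioned as an accuracy check rather than a speedup.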

examples/bit_packing/int4_config.json

Lines changed: 1 addition & 2 deletions

@@ -12,6 +12,5 @@
       },
       "targets": ["Linear"]
     }
-  },
-  "ignore": ["lm_head"]
+  }
 }
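To make the config edit concrete, here is a hedged sketch of loading an int4 config through the same `model_validate_json` entry point the updated example script uses. The JSON payload is illustrative only: the field names follow `compressed-tensors` conventions, but the values are assumptions, not the repository's actual `int4_config.json`, and may need adjustment per library version.

```python
import json

from compressed_tensors.quantization import QuantizationConfig

# illustrative int4 weight-quantization config (assumed schema, not copied
# from this repo)
raw_config = json.dumps(
    {
        "format": "pack-quantized",
        "config_groups": {
            "group_1": {
                "weights": {"num_bits": 4, "type": "int", "symmetric": True},
                "targets": ["Linear"],
            }
        },
    }
)

# same Pydantic v2 entry point the updated example script uses
config = QuantizationConfig.model_validate_json(raw_config)
```

Note that dropping the `"ignore": ["lm_head"]` entry means this example config no longer excludes the language-model head from quantization.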

examples/bitmask_compression.ipynb

Lines changed: 20 additions & 20 deletions

@@ -15,7 +15,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [

@@ -29,7 +29,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {

@@ -40,30 +40,30 @@
       "    (embed_tokens): Embedding(32000, 768)\n",
       "    (layers): ModuleList(\n",
       "      (0-11): 12 x LlamaDecoderLayer(\n",
-      "        (self_attn): LlamaSdpaAttention(\n",
+      "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "          (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "          (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "          (o_proj): Linear(in_features=768, out_features=768, bias=False)\n",
-      "          (rotary_emb): LlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n",
       "          (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n",
       "          (down_proj): Linear(in_features=2048, out_features=768, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
-      "        (input_layernorm): LlamaRMSNorm()\n",
-      "        (post_attention_layernorm): LlamaRMSNorm()\n",
+      "        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)\n",
+      "        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
-      "    (norm): LlamaRMSNorm()\n",
+      "    (norm): LlamaRMSNorm((768,), eps=1e-05)\n",
+      "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n",
       ")"
      ]
     },
-    "execution_count": 9,
+    "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }

@@ -77,14 +77,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "The example layer model.layers.0.self_attn.q_proj.weight has sparsity 0.50%\n"
+      "The example layer model.layers.0.self_attn.q_proj.weight has sparsity 50%\n"
      ]
     }
    ],

@@ -93,42 +93,42 @@
    "state_dict = model.state_dict()\n",
    "state_dict.keys()\n",
    "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
-   "print(f\"The example layer {example_layer} has sparsity {torch.sum(state_dict[example_layer] == 0).item() / state_dict[example_layer].numel():.2f}%\")"
+   "print(f\"The example layer {example_layer} has sparsity {100 * state_dict[example_layer].eq(0).sum().item() / state_dict[example_layer].numel():.0f}%\")"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "The model is 31.67% sparse overall\n"
+     "The model is 32% sparse overall\n"
    ]
   }
  ],
  "source": [
-   "# we can inspect to total sparisity of the state_dict\n",
+   "# we can inspect to total sparsity of the state_dict\n",
    "total_num_parameters = 0\n",
    "total_num_zero_parameters = 0\n",
    "for key in state_dict:\n",
    "    total_num_parameters += state_dict[key].numel()\n",
    "    total_num_zero_parameters += state_dict[key].eq(0).sum().item()\n",
-   "print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.2f}% sparse overall\")"
+   "print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.0f}% sparse overall\")"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.92it/s]\n"
+     "Compressing model: 100%|██████████| 111/111 [00:00<00:00, 313.39it/s]\n"
    ]
   },
   {

@@ -168,7 +168,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {

@@ -185,8 +185,8 @@
    "## load the uncompressed safetensors to memory ##\n",
    "state_dict_1 = {}\n",
    "with safe_open('model.safetensors', framework=\"pt\") as f:\n",
-   "  for key in f.keys():\n",
-   "    state_dict_1[key] = f.get_tensor(key)\n",
+   "    for key in f.keys():\n",
+   "        state_dict_1[key] = f.get_tensor(key)\n",
    "\n",
    "## load the compressed-tensors to memory ##\n",
    "config = BitmaskConfig() # we need to specify the method for decompression\n",
