Commit d947364
[Experimental] Mistral-format FP8 quantization (#1359)
https://huggingface.co/nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8

```
vllm serve nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8 --tokenizer_mode mistral --config_format mistral --load_format mistral --quantization fp8

lm_eval --model local-completions --model_args model=nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8,tokenizer=mistralai/Mistral-Small-3.1-24B-Instruct-2503,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=500,tokenized_requests=False --tasks gsm8k --num_fewshot 5

local-completions (model=nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8,tokenizer=mistralai/Mistral-Small-3.1-24B-Instruct-2503,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=500,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8886|±  |0.0087|
|     |       |strict-match    |     5|exact_match|↑  |0.8848|±  |0.0088|
```

---------

Signed-off-by: mgoin <michael@neuralmagic.com>
1 parent c1c8541 commit d947364

File tree

2 files changed: +165, -0 lines changed


experimental/mistral/README.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# Mistral-format model compression (experimental)

This folder contains tools for compressing Mistral-format models, like `mistralai/Devstral-Small-2505` and `mistralai/Magistral-Small-2506`.

## FP8 W8A8 Quantization

This script quantizes Mistral-format models to FP8. It is not for use with HuggingFace-format models.

### 1. Download the model

Download the model and save it to a new "FP8" folder. We use `mistralai/Magistral-Small-2506` as an example.

```bash
huggingface-cli download mistralai/Magistral-Small-2506 --local-dir Magistral-Small-2506-FP8
```
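
Optionally, you can try restricting the download to just the native-format files so there is less to clean up in step 2. This is a hedged shortcut; it assumes the repo's native files are `params.json`, `tekken.json`, and `consolidated*.safetensors`:

```bash
huggingface-cli download mistralai/Magistral-Small-2506 \
  --include "params.json" "tekken.json" "consolidated*.safetensors" \
  --local-dir Magistral-Small-2506-FP8
```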

### 2. Clean up HuggingFace-specific files

Models from the Hub often include files for both the native Mistral format and the HuggingFace `transformers` format. This script works on the native format, so the `transformers` files should be removed to avoid confusion.

The HuggingFace-specific files are typically `config.json`, `model-000*-of-000*.safetensors`, and `model.safetensors.index.json`. The `params.json`, `tekken.json`, and `consolidated.safetensors` files belong to the native format.

Before deleting, it's a good idea to look at the files in the directory to understand what you're removing.
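
For example (the exact listing varies by model):

```bash
ls -l Magistral-Small-2506-FP8
```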

Once you're ready, remove the `transformers`-specific files:

```bash
rm Magistral-Small-2506-FP8/config.json Magistral-Small-2506-FP8/model.safetensors.index.json Magistral-Small-2506-FP8/model-000*
```

### 3. Run the quantization script

Now, run the FP8 quantization script on the directory. This will modify the `.safetensors` files in-place and update `params.json` and `consolidated.safetensors`.

```bash
python fp8_quantize.py Magistral-Small-2506-FP8
```

### 4. Use the quantized model

The model should now be ready to use in vLLM!

```bash
vllm serve Magistral-Small-2506-FP8 --tokenizer-mode mistral --config-format mistral --load-format mistral --tool-call-parser mistral --enable-auto-tool-choice
```
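
Once the server is up, a quick request against the OpenAI-compatible API confirms the model loads and responds. This assumes the default port 8000; the model name is the path passed to `vllm serve`:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "Magistral-Small-2506-FP8",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}]
      }'
```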

experimental/mistral/fp8_quantize.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import argparse
import os
import json
import torch
import safetensors.torch

def per_tensor_quantize(tensor):
    """Quantize a tensor to FP8 using a per-tensor static scaling factor."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    if tensor.numel() == 0:
        min_val, max_val = torch.tensor(-16.0, dtype=tensor.dtype), torch.tensor(16.0, dtype=tensor.dtype)
    else:
        min_val, max_val = tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale

def is_quantizable(name):
    """Check if the tensor name indicates it can be quantized."""
    return name.startswith('layers.') and name.endswith(('.wk.weight', '.wo.weight', '.wq.weight', '.wv.weight', '.w1.weight', '.w2.weight', '.w3.weight'))

def process_safetensors_file(file_path):
    """Process a single safetensors file in-place, quantizing weights to FP8."""
    print(f"Processing {file_path}")
    tensors = safetensors.torch.load_file(file_path)

    modified_tensors = {}
    for name, tensor in tensors.items():
        if is_quantizable(name):
            print("Quantizing", name)
            qweight, scale = per_tensor_quantize(tensor)
            modified_tensors[name] = qweight
            modified_tensors[f"{name[:-len('weight')]}qscale_weight"] = scale
        else:
            modified_tensors[name] = tensor

    safetensors.torch.save_file(modified_tensors, file_path)
    print(f"Updated {file_path} with quantized tensors")

def update_index_file(index_file_path):
    """Update the index file for the quantized model."""
    print(f"Updating index file: {index_file_path}")
    with open(index_file_path, 'r') as f:
        index = json.load(f)

    new_weight_map = {}
    for tensor_name, file_name in index['weight_map'].items():
        new_weight_map[tensor_name] = file_name
        if is_quantizable(tensor_name):
            new_weight_map[f"{tensor_name[:-len('weight')]}qscale_weight"] = file_name

    index['weight_map'] = new_weight_map

    # Recalculate total_size
    total_size = sum(os.path.getsize(os.path.join(os.path.dirname(index_file_path), file))
                     for file in set(index['weight_map'].values()))
    index['metadata']['total_size'] = total_size

    with open(index_file_path, 'w') as f:
        json.dump(index, f, indent=2)
    print(f"Updated index file {index_file_path}")

def update_config(config_file_path):
    """Update the params.json file for the quantized model."""
    print(f"Updating config file: {config_file_path}")
    with open(config_file_path, 'r') as f:
        config = json.load(f)

    config["quantization"] = {
        "config_groups": {
            "group_0": {
                "input_activations": {
                    "dynamic": True,
                    "num_bits": 8,
                    "observer": None,
                    "strategy": "token",
                    "symmetric": True,
                    "type": "float"
                },
                "targets": ["Linear"],
                "weights": {
                    "dynamic": False,
                    "num_bits": 8,
                    "observer": "minmax",
                    "strategy": "tensor",
                    "symmetric": True,
                    "type": "float"
                }
            }},
        "format": "float-quantized",
        "ignore": ["lm_head", "output"],
        "quant_method": "compressed-tensors",
        "quantization_status": "compressed"
    }

    with open(config_file_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"Updated config file {config_file_path}")

def process_directory(directory):
    """Process all safetensors files in the given directory."""
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        if filename.endswith('.safetensors'):
            process_safetensors_file(file_path)
        elif filename == 'consolidated.safetensors.index.json':
            update_index_file(file_path)
        elif filename == 'params.json':
            update_config(file_path)
        else:
            print(f"Skipping unrecognized file: {filename}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert mistral safetensors model to FP8 in-place.')
    parser.add_argument('directory', type=str, help='The directory containing the safetensors files and index file.')

    args = parser.parse_args()
    process_directory(args.directory)
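
As a quick sanity check after running the script, you can list the FP8 scale tensors it should have added. This is a minimal sketch, assuming the single-file `consolidated.safetensors` layout and directory name used in the README above:

```bash
python -c "
from safetensors import safe_open
# Open the quantized checkpoint and look for the per-tensor scales written by fp8_quantize.py
with safe_open('Magistral-Small-2506-FP8/consolidated.safetensors', framework='pt') as f:
    scales = [n for n in f.keys() if n.endswith('qscale_weight')]
    print(len(scales), 'qscale_weight tensors found, e.g.', scales[:2])
"
```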
