
Commit 3be5153

awni and angeloskath authored

Dynamic quants (#202)

* dynamic quants + reorg
* readme
* angelos fix
* Change sensitivity metric
* update version
* fix rebase

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

1 parent 19153e1 commit 3be5153

File tree

8 files changed: +318, -33 lines changed


mlx_lm/LEARNED_QUANTS.md

Lines changed: 38 additions & 9 deletions

@@ -1,21 +1,26 @@
 # Learned Quantization
 
-To reduce the quality loss from quantization MLX LM has two options:
+To reduce the quality loss from quantization MLX LM has several options:
 
 - Distilled Weight Quantization (DWQ)
-- Activation-aware Weight Quantization (AWQ)[^1].
+- Activation-aware Weight Quantization (AWQ)[^1]
+- Dynamic quantization
 
-Both DWQ and AWQ use an example dataset to tune parameters of the model. DWQ
-fine-tunes non-quantized parameters (including quantization scales and biases)
-using the non-quantized model as a teacher. AWQ scales and clips the weights
-prior to quantization. The scaling and clipping values are found with a grid
-search minimizing the distance from the quantized hidden activations to the
-non-quantized hidden activations
+All methods use calibration data to tune parameters or hyper-parameters of the
+model. DWQ fine-tunes non-quantized parameters (including quantization scales
+and biases) using the non-quantized model as a teacher. AWQ scales and clips
+the weights prior to quantization. Dynamic quantization estimates the
+sensitivity of a model's outputs to each layer and uses a higher precision for
+layers which have higher sensitivity.
+
+Dynamic quantization is the fastest to run. DWQ takes longer but typically
+yields better results. You can also cascade methods. For example, a dynamically
+quantized model can be further refined with DWQ.
 
 To get started, first install the requirements:
 
 ```
-pip install mlx-lm[lwq]
+pip install mlx-lm[quant]
 ```
 
 ### DWQ
@@ -66,6 +71,30 @@ A few options to reduce memory use for DWQ:
   `--max-seq-length 512` reduces the memory and still gets good results.
 - Use a smaller batch size, e.g. `--batch-size 1`
 
+### Dynamic Quantization
+
+Use `mlx_lm.dynamic_quant` to generate a dynamic quantization of a given model.
+For example:
+
+```bash
+mlx_lm.dynamic_quant --model mistralai/Mistral-7B-Instruct-v0.3
+```
+
+The script will estimate the sensitivity for each quantizable layer in the
+model. It will then quantize the model using higher precision (default 5 bits)
+for the more sensitive layers and lower precision (default 4 bits) for the
+rest. The script also saves a JSON file with each layer's sensitivities, which
+saves needing to compute them multiple times to make different precision quants
+of the same model.
+
+Some important options are:
+
+- `--target-bpw`: The target bits-per-weight. For a given set of quantization
+  parameters only certain ranges are possible. For example, with the default
+  parameters a BPW in the range `[4.5, 5.5]` is achievable.
+- `--sensitivities`: A path to a precomputed sensitivities file.
+- `--low-bits`: The number of bits to use for the less sensitive layers.
+- `--high-bits`: The number of bits to use for the more sensitive layers.
 
 ### AWQ

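For a concrete workflow built from the options above (a sketch, not from the commit: output paths and target BPWs are illustrative, and the sensitivities filename follows the `{model}_sensitivities.json` pattern that `mlx_lm.quant.dynamic_quant` writes to the working directory), one sensitivity pass can be reused for quants at different bit widths:

```bash
# First run: estimate sensitivities, write
# mistralai_Mistral-7B-Instruct-v0.3_sensitivities.json, and save a ~5.0 BPW quant
mlx_lm.dynamic_quant --model mistralai/Mistral-7B-Instruct-v0.3 \
    --target-bpw 5.0 --mlx-path mlx_model_5.0bpw --report-ppl

# Second run: reuse the saved sensitivities to hit a different target BPW
mlx_lm.dynamic_quant --model mistralai/Mistral-7B-Instruct-v0.3 \
    --sensitivities mistralai_Mistral-7B-Instruct-v0.3_sensitivities.json \
    --target-bpw 4.6 --mlx-path mlx_model_4.6bpw
```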
mlx_lm/__main__.py

Lines changed: 3 additions & 2 deletions

@@ -5,8 +5,9 @@
 
 if __name__ == "__main__":
     subcommands = {
-        "awq",
-        "dwq",
+        "quant.awq",
+        "quant.dwq",
+        "quant.dynamic_quant",
         "cache_prompt",
         "chat",
         "convert",

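The dotted entries above presumably name modules under the `mlx_lm` package, which is why the quantization scripts move to `quant.awq`, `quant.dwq`, and the new `quant.dynamic_quant` after the reorg. As a minimal sketch (an illustration only; the dispatch code itself is not shown in this diff), such a table could be resolved to each module's `main()` like this:

```python
import importlib
import sys


def run_subcommand(name: str, subcommands: set) -> None:
    # e.g. "quant.dynamic_quant" resolves to mlx_lm.quant.dynamic_quant.main()
    if name not in subcommands:
        sys.exit(f"Unknown subcommand: {name}")
    importlib.import_module(f"mlx_lm.{name}").main()
```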
mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.24.1"
+__version__ = "0.25.0"

mlx_lm/awq.py renamed to mlx_lm/quant/awq.py

Lines changed: 2 additions & 18 deletions

@@ -14,6 +14,7 @@
 
 from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.switch_layers import SwitchLinear
+from mlx_lm.quant.utils import load_data
 from mlx_lm.utils import (
     fetch_from_hub,
     get_model_path,
@@ -510,23 +511,6 @@ def __call__(self, x: mx.array, *args, **kwargs):
         )
 
 
-def load_dataset(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
-    save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt"
-    if not save_dir.exists():
-        save_dir.parent.mkdir(parents=True, exist_ok=True)
-        url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt"
-        request.urlretrieve(url, save_dir)
-    with open(save_dir) as fid:
-        texts = fid.read()
-    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
-
-    # select random non-overlapping chunks
-    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
-    tokens = tokens.reshape(-1, sequence_length)
-    segments = mx.random.permutation(tokens.shape[0])[:num_samples]
-    return tokens[segments]
-
-
 def update_config(
     model: nn.Module,
     config: Dict[str, Any],
@@ -578,7 +562,7 @@ def main():
     if (awq_config := AWQ_MODEL_CONFIGS.get(model_type, None)) is None:
         raise NotImplementedError(f"AWQ support for {model_type} models NYI.")
 
-    calibration_data = load_dataset(tokenizer, args.num_samples, args.sequence_length)
+    calibration_data = load_data(tokenizer, args.num_samples, args.sequence_length)
 
     calibration_data = dist_split(calibration_data, group)
File renamed without changes.

mlx_lm/quant/dynamic_quant.py

Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@
# Copyright © 2025 Apple Inc.

import argparse
import copy
import json
import math

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from tqdm import tqdm

from mlx_lm.quant.utils import load_data
from mlx_lm.utils import (
    compute_bits_per_weight,
    fetch_from_hub,
    get_model_path,
    quantize_model,
    save,
)


def eval_ppl(model, data, batch_size=8):
    all_loss = 0.0
    ntoks = 0
    for s in range(0, len(data), batch_size):
        batch = data[s : s + batch_size]
        logits = model(batch[:, :-1]).astype(mx.float32)
        losses = nn.losses.cross_entropy(logits, batch[:, 1:])
        all_loss += losses.sum().item()
        ntoks += losses.size
    ppl = math.exp(all_loss / ntoks)
    return ppl

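For reference, `eval_ppl` above computes standard token-level perplexity: writing $x_t$ for the $t$-th token and $N$ for the number of scored tokens (notation added here, not in the source),

$$\mathrm{PPL} = \exp\Big(\frac{1}{N}\sum_{t=1}^{N} -\log p_\theta(x_t \mid x_{<t})\Big),$$

which is exactly the running `all_loss / ntoks` passed through `math.exp`.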
def estimate_sensitivities(
    model,
    data,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    batch_size = 4
    layers = tree_flatten(model.leaf_modules(), is_leaf=nn.Module.is_module)
    layers = {k: l for k, l in layers if hasattr(l, "to_quantized")}

    q_model = copy.deepcopy(model)

    def qdq(w, bits, group_size):
        w, s, b = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(w, scales=s, biases=b, bits=bits, group_size=group_size)

    q_layers = copy.deepcopy(layers)
    for l in q_layers.values():
        l.weight = qdq(l.weight, low_bits, low_group_size)
    q_model.freeze()
    q_model.update_modules(tree_unflatten(list(q_layers.items())))

    def log_norm(x):
        x = x.astype(mx.float32)
        return x - mx.logsumexp(x, axis=-1, keepdims=True)

    def loss_fn(batch, targets):
        logprobs = log_norm(q_model(batch))
        return nn.losses.kl_div_loss(logprobs, targets, reduction="mean")

    grad_accum = tree_map(lambda x: mx.zeros(x.shape), q_model.trainable_parameters())
    for e, s in tqdm(
        enumerate(range(0, len(data), batch_size)),
        total=len(data) // batch_size,
        desc="Estimating sensitivities",
    ):
        batch = data[s : s + batch_size]
        targets = log_norm(model(batch))
        mx.eval(targets)
        _, grads = nn.value_and_grad(q_model, loss_fn)(batch, targets)
        grad_accum = tree_map(lambda x, y: x + y, grad_accum, grads)
        mx.eval(grad_accum)

    def compute_sensitivity(gradient, low_q_weight, original_weight):
        n_batches = (len(data) + batch_size - 1) // batch_size
        gradient = gradient / n_batches
        high_q_weight = qdq(original_weight, high_bits, high_group_size)
        param_size = original_weight.size / 1e6
        alignment = (gradient * (low_q_weight - high_q_weight)).sum()
        return alignment / param_size

    sensitivities = tree_map(
        compute_sensitivity,
        grad_accum,
        q_model.parameters(),
        model.parameters(),
    )
    mx.eval(sensitivities)

    sensitivities = [(k[:-7], s.item()) for k, s in tree_flatten(sensitivities)]

    return sensitivities

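One way to read `compute_sensitivity` above (an editorial gloss with notation not in the source): let $\mathcal{L}$ be the KL divergence from the full-precision teacher averaged over the calibration batches, and let $\widetilde{W}_\ell^{\mathrm{low}}$ and $\widetilde{W}_\ell^{\mathrm{high}}$ be the low- and high-bit quantize-dequantize round trips of layer $\ell$'s weight, which has $N_\ell$ parameters. The score assigned to each layer is

$$s_\ell = \frac{\nabla_{W_\ell}\mathcal{L}\big(\widetilde{W}^{\mathrm{low}}\big)\cdot\big(\widetilde{W}_\ell^{\mathrm{low}} - \widetilde{W}_\ell^{\mathrm{high}}\big)}{N_\ell / 10^6},$$

a first-order Taylor estimate of how much the loss would drop, per million parameters, if that layer alone were promoted from the low-bit to the high-bit setting. Layers with large $s_\ell$ are the ones `estimate_threshold` below sends to the higher precision.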
def estimate_threshold(
    model,
    sensitivities,
    target_bpw,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    def predicate(p, m, high_threshold):
        if not hasattr(m, "to_quantized"):
            return False
        if sensitivities[p] > high_threshold:
            return {"bits": high_bits, "group_size": high_group_size}
        return True

    # Binary search for the threshold
    sens_vals = list(sensitivities.values())
    min_threshold = min(sens_vals)
    max_threshold = max(sens_vals)
    tolerance = 1e-3 * (max_threshold - min_threshold)
    while (max_threshold - min_threshold) > tolerance:
        mid = (max_threshold + min_threshold) / 2
        class_predicate = lambda p, m: predicate(p, m, mid)
        q_model = copy.deepcopy(model)
        nn.quantize(
            q_model,
            group_size=low_group_size,
            bits=low_bits,
            class_predicate=class_predicate,
        )
        bpw = compute_bits_per_weight(q_model)
        if bpw > target_bpw:
            min_threshold = mid
        else:
            max_threshold = mid

    return (max_threshold + min_threshold) / 2

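The binary search above relies on the BPW moving monotonically with the threshold: lowering the threshold promotes more layers to the high-bit setting. As a rough sketch of why the default parameters give the `[4.5, 5.5]` BPW range quoted in LEARNED_QUANTS.md, assume each group of `group_size` weights also stores an fp16 scale and an fp16 bias (an assumption about the affine quantization layout, not stated in this diff):

```python
def approx_bpw(bits: int, group_size: int) -> float:
    # quantized weights plus one fp16 scale and one fp16 bias per group
    return bits + 2 * 16 / group_size


print(approx_bpw(4, 64))  # 4.5: every layer at --low-bits 4
print(approx_bpw(5, 64))  # 5.5: every layer at --high-bits 5
# Any --target-bpw between the two extremes is reachable by sliding the
# sensitivity threshold, which is what the binary search does.
```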
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-0.6B-base")
    parser.add_argument(
        "--mlx-path", default="mlx_model", help="Path to save the model"
    )
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument(
        "--sensitivities",
        type=str,
        default=None,
        help="Path to a pre-computed sensitivity JSON file.",
    )
    parser.add_argument(
        "--target-bpw", type=float, default=5.0, help="Target bits per weight."
    )
    parser.add_argument("--low-bits", type=int, default=4)
    parser.add_argument("--low-group-size", type=int, default=64)
    parser.add_argument("--high-bits", type=int, default=5)
    parser.add_argument("--high-group-size", type=int, default=64)
    parser.add_argument(
        "--report-ppl",
        action="store_true",
        help="Compute the perplexity of the base and quantized models.",
    )

    args = parser.parse_args()

    group = mx.distributed.init()

    if args.sensitivities is None:
        model_path = get_model_path(args.model, revision=None)
        model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
        mx.random.seed(args.seed)
        data = load_data(tokenizer, num_samples=-1, sequence_length=512)

        sensitivities = estimate_sensitivities(
            model,
            data,
            args.low_bits,
            args.low_group_size,
            args.high_bits,
            args.high_group_size,
        )
        model_name = args.model.replace("/", "_")
        with open(f"{model_name}_sensitivities.json", "w") as fid:
            json.dump(sensitivities, fid)
    else:
        with open(args.sensitivities, "r") as fid:
            sensitivities = json.load(fid)

    sensitivities = dict(sensitivities)
    model_path = get_model_path(args.model, revision=None)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
    mx.random.seed(args.seed)
    data = load_data(tokenizer, num_samples=-1, sequence_length=512)

    if args.report_ppl:
        ppl = eval_ppl(model, data)
        print(f"Original PPL: {ppl:.3f}")

    threshold = estimate_threshold(
        model,
        sensitivities,
        target_bpw=args.target_bpw,
        low_bits=args.low_bits,
        low_group_size=args.low_group_size,
        high_bits=args.high_bits,
        high_group_size=args.high_group_size,
    )

    def quant_predicate(p, m, _):
        if not hasattr(m, "to_quantized"):
            return False
        if sensitivities[p] > threshold:
            return {"bits": args.high_bits, "group_size": args.high_group_size}
        return True

    model, config = quantize_model(
        model,
        config,
        q_group_size=args.low_group_size,
        q_bits=args.low_bits,
        quant_predicate=quant_predicate,
    )

    if args.report_ppl:
        ppl = eval_ppl(model, data)
        print(f"Quantized PPL: {ppl:.3f}")

    save(
        args.mlx_path,
        model_path,
        model,
        tokenizer,
        config,
        hf_repo=args.model,
    )


if __name__ == "__main__":
    main()

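Once saved, the mixed-precision model loads like any other mlx-lm model. A usage sketch (prompt and token count are illustrative):

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx_model")  # the directory passed as --mlx-path
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
```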
mlx_lm/quant/utils.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
# Copyright © 2025 Apple Inc.

from pathlib import Path

import mlx.core as mx


def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
    save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt"
    if not save_dir.exists():
        from urllib import request

        save_dir.parent.mkdir(parents=True, exist_ok=True)
        url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt"
        request.urlretrieve(url, save_dir)
    with open(save_dir) as fid:
        texts = fid.read()
    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]

    # select random non-overlapping chunks
    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
    tokens = tokens.reshape(-1, sequence_length)
    segments = mx.random.permutation(tokens.shape[0])
    if num_samples > 0:
        segments = segments[:num_samples]
    return tokens[segments]

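A hypothetical usage sketch of `load_data` (the tokenizer is assumed to be the one returned by `fetch_from_hub`, as in the scripts above):

```python
# 32 random, non-overlapping 512-token calibration sequences; passing
# num_samples=-1 keeps every available chunk, as dynamic_quant.py does.
data = load_data(tokenizer, num_samples=32, sequence_length=512)
print(data.shape)  # (32, 512)
```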