
Commit 19153e1

Add total params to metadata + cleanup (#207)
* add total params to metadata + cleanup
* comments

1 parent 1db99d4

File tree

11 files changed: +56 −285 lines changed

mlx_lm/MERGE.md

Lines changed: 0 additions & 50 deletions
This file was deleted.

mlx_lm/__main__.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
     "fuse",
     "generate",
     "lora",
-    "merge",
     "server",
     "manage",
     "upload",

mlx_lm/awq.py

Lines changed: 1 addition & 2 deletions
@@ -594,11 +594,10 @@ def main():
     )

     config = update_config(model, config)
-    weights = dict(tree_flatten(model.parameters()))
     save(
         args.mlx_path,
         model_path,
-        weights,
+        model,
         tokenizer,
         config,
         hf_repo=args.model,
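
This is the pattern repeated across awq.py, convert.py, dwq.py, and fuse.py: save() now receives the model itself rather than a pre-flattened weights dict, presumably so it can compute metadata such as the total parameter count from the module tree. A minimal before/after sketch, with the argument order taken from the diff above (save()'s full signature is not shown in this commit):

    # Before: every caller flattened the parameters itself.
    from mlx.utils import tree_flatten
    weights = dict(tree_flatten(model.parameters()))
    save(args.mlx_path, model_path, weights, tokenizer, config, hf_repo=args.model)

    # After: pass the model; flattening happens inside save().
    save(args.mlx_path, model_path, model, tokenizer, config, hf_repo=args.model)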

mlx_lm/convert.py

Lines changed: 11 additions & 19 deletions
@@ -6,7 +6,7 @@

 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_flatten
+from mlx.utils import tree_map_with_path

 from .utils import (
     dequantize_model,
@@ -120,44 +120,36 @@ def convert(

     if dtype is None:
         dtype = config.get("torch_dtype", None)
-    weights = dict(tree_flatten(model.parameters()))
     if dtype in MODEL_CONVERSION_DTYPES:
         print("[INFO] Using dtype:", dtype)
         dtype = getattr(mx, dtype)
+        cast_predicate = getattr(model, "cast_predicate", lambda _: True)

-        if hasattr(model, "cast_predicate"):
-            cast_predicate = model.cast_predicate()
-        else:
-            cast_predicate = lambda _: True
-        weights = {
-            k: (
-                v.astype(dtype)
-                if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating)
-                else v
-            )
-            for k, v in weights.items()
-        }
+        def set_dtype(k, v):
+            if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating):
+                return v.astype(dtype)
+            else:
+                return v
+
+        model.update(tree_map_with_path(set_dtype, model.parameters()))

     if quantize and dequantize:
         raise ValueError("Choose either quantize or dequantize, not both.")

     if quantize:
         print("[INFO] Quantizing")
-        model.load_weights(list(weights.items()))
-        weights, config = quantize_model(
+        model, config = quantize_model(
             model, config, q_group_size, q_bits, quant_predicate=quant_predicate
         )

     if dequantize:
         print("[INFO] Dequantizing")
         model = dequantize_model(model)
-        weights = dict(tree_flatten(model.parameters()))

-    del model
     save(
         mlx_path,
         model_path,
-        weights,
+        model,
         tokenizer,
         config,
         hf_repo=hf_path,
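
The dtype cast now rewrites the parameter tree in place instead of materializing a flat weights dict first. Below is a standalone sketch of the same pattern on a toy model; tree_map_with_path is the real mlx.utils helper, which passes each leaf's dot-separated path (e.g. "layers.0.weight") and its array to the callback, while the model and printout here are purely illustrative:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_map_with_path

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # toy model
    dtype = mx.bfloat16

    def set_dtype(path, param):
        # Cast floating-point leaves only; integer buffers pass through.
        if mx.issubdtype(param.dtype, mx.floating):
            return param.astype(dtype)
        return param

    # Visit every parameter with its path, then write the cast arrays
    # back into the module tree.
    model.update(tree_map_with_path(set_dtype, model.parameters()))
    print(model.parameters()["layers"][0]["weight"].dtype)  # bfloat16

The per-model cast_predicate hook, now read with getattr and a permissive default, slots into the callback exactly as in the diff.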

mlx_lm/dwq.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def main():
     save(
         args.mlx_path,
         model_path,
-        dict(tree_flatten(q_model.parameters())),
+        q_model,
         tokenizer,
         config,
         hf_repo=args.model,

mlx_lm/fuse.py

Lines changed: 3 additions & 4 deletions
@@ -86,18 +86,16 @@ def main() -> None:
         model = dequantize(model)
         config.pop("quantization", None)

-    weights = dict(tree_flatten(model.parameters()))
-
     save_path = Path(args.save_path)
     hf_path = args.hf_path or (args.model if not Path(args.model).exists() else None)
     save(
         save_path,
         model_path,
-        weights,
+        model,
         tokenizer,
         config,
         hf_repo=hf_path,
-        donate_weights=False,
+        donate_model=False,
     )

     if args.export_gguf:
@@ -106,6 +104,7 @@ def main() -> None:
             raise ValueError(
                 f"Model type {model_type} not supported for GGUF conversion."
             )
+        weights = dict(tree_flatten(model.parameters()))
         convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))

     if args.upload_repo is not None:
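
Here donate_weights is renamed donate_model to match the new argument, and fuse.py still passes False. Assuming donation means save() may consume the model's buffers to cut peak memory (inferred from the flag's name, not stated in this commit), the call order shows why it must stay off in this file:

    save(save_path, model_path, model, tokenizer, config,
         hf_repo=hf_path, donate_model=False)  # keep the parameters alive...

    # ...because GGUF export flattens and reads them after save() returns.
    weights = dict(tree_flatten(model.parameters()))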

mlx_lm/merge.py

Lines changed: 0 additions & 176 deletions
This file was deleted.

mlx_lm/tuner/utils.py

Lines changed: 14 additions & 10 deletions
@@ -263,20 +263,24 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model


-def nparams(module):
-    if hasattr(module, "bits"):
-        n = 0 if not hasattr(module, "bias") else module.bias.size
-        return n + module.weight.size * 32 // module.bits
-    return sum(v.size for _, v in tree_flatten(module.parameters()))
-
-
-def print_trainable_parameters(model):
+def get_total_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
-    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
+
+    def nparams(m):
+        if hasattr(m, "bits"):
+            n = 0 if not hasattr(m, "bias") else m.bias.size
+            return n + m.weight.size * 32 // m.bits
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    return sum(nparams(m) for _, m in leaf_modules)
+
+
+def print_trainable_parameters(model):
+    total_p = get_total_parameters(model) / 1e6
     trainable_p = (
-        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 1e6
     )
     print(
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "