Skip to content

Commit ee94992

Browse files
authored
Fix parameter count for quantized models (#137)
* fix parameter count
* cleanup
* try extension module
1 parent 169650a commit ee94992

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

Applications/LLMEval/ContentView.swift

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -193,8 +193,13 @@ class LLMEvaluator {
193193
"Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
194194
}
195195
}
196+
let numParams = await modelContainer.perform {
197+
[] model, _ in
198+
return model.numParameters()
199+
}
200+
196201
self.modelInfo =
197-
"Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
202+
"Loaded \(modelConfiguration.id). Weights: \(numParams / (1024*1024))M"
198203
loadState = .loaded(modelContainer)
199204
return modelContainer
200205

Libraries/LLM/LLMModel.swift

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,27 @@ public actor ModelContainer {
5757
}
5858
}
5959

60+
extension Module {

    /// Compute the total number of parameters in a possibly quantized model.
    ///
    /// For quantized layers the packed weights do not reflect the logical
    /// parameter count, so the count is reconstructed from the quantization
    /// scales: each scale covers `groupSize` dequantized weight elements.
    /// All other leaf modules are counted by summing the sizes of their
    /// flattened parameter arrays.
    public func numParameters() -> Int {
        leafModules().flattenedValues().reduce(0) { total, module in
            if let quantized = module as? QuantizedLinear {
                // scales has one entry per group of `groupSize` weights
                return total + quantized.scales.size * quantized.groupSize
            }
            if let quantized = module as? QuantizedEmbedding {
                return total + quantized.scales.size * quantized.groupSize
            }
            // unquantized leaf: sum the element counts of its parameters
            let count = module.parameters().flattenedValues().reduce(0) { $0 + $1.size }
            return total + count
        }
    }
}
80+
6081
/// Interface for all LLM Models
6182
public protocol LLMModel: Module {
6283

Tools/llm-tool/LoraCommands.swift

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,7 @@ struct LoRAModelArguments: ParsableArguments, Sendable {
5858
}
5959

6060
func describe(model: Module) {
61-
let totalParameterCount = model.parameters()
62-
.flattenedValues().map { $0.size }.reduce(0, +)
61+
let totalParameterCount = model.numParameters()
6362
let trainableParameterCount = model.trainableParameters()
6463
.flattenedValues().map { $0.size }.reduce(0, +)
6564

0 commit comments

Comments (0)