Commit 1db9d3a

fix #291 -- support heterogenous quant config (#293)
- support per layer quant config
1 parent 39085b8 · commit 1db9d3a

File tree

10 files changed: +462 −65 lines

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
// Copyright © 2025 Apple Inc.

import Foundation

/// Base ``LanguageModel`` configuration -- provides `modelType`
/// and `quantization` (used in loading the model).
///
/// This is used by ``ModelFactory/load(hub:configuration:progressHandler:)``
/// to determine the type of model to load.
public struct BaseConfiguration: Codable, Sendable {
    public let modelType: String

    public struct Quantization: Codable, Sendable, Equatable {
        public init(groupSize: Int, bits: Int) {
            self.groupSize = groupSize
            self.bits = bits
        }

        public let groupSize: Int
        public let bits: Int

        public var asTuple: (Int, Int) { (groupSize, bits) }

        enum CodingKeys: String, CodingKey {
            case groupSize = "group_size"
            case bits = "bits"
        }
    }

    /// Handling instructions for ``PerLayerQuantization``.
    public enum QuantizationOption: Sendable {
        case skip
        case quantize(Quantization)
    }

    /// Per-layer ``Quantization`` values with an optional default.
    public struct PerLayerQuantization: Sendable {
        public var quantization: Quantization? = nil
        public var perLayerQuantization: [String: QuantizationOption]

        public init(
            quantization: BaseConfiguration.Quantization? = nil,
            perLayerQuantization: [String: BaseConfiguration.QuantizationOption]
        ) {
            self.quantization = quantization
            self.perLayerQuantization = perLayerQuantization
        }

        /// The quantization to apply for the given layer name, or nil for no quantization.
        public func quantization(layer: String) -> Quantization? {
            if let perLayer = perLayerQuantization[layer] {
                switch perLayer {
                case .skip:
                    return nil
                case .quantize(let quantization):
                    return quantization
                }
            } else {
                return quantization
            }
        }
    }

    /// Special codable to support a mixed key: Int / key: Quantization / key: Bool
    /// structure for heterogeneous quantization, e.g.
    ///
    /// ```
    /// "quantization": {
    ///     "group_size": 64,
    ///     "bits": 4,
    ///     "model.embed_tokens": {
    ///         "group_size": 32,
    ///         "bits": 4
    ///     },
    ///     "model.layers.0.self_attn.q_norm": false
    /// }
    /// ```
    ///
    /// This mixed-type structure requires manual decoding.
    struct QuantizationContainer: Codable, Sendable {
        var quantization: Quantization
        var perLayerQuantization: PerLayerQuantization

        // based on Dictionary's coding key
        internal struct _DictionaryCodingKey: CodingKey {
            internal let stringValue: String
            internal let intValue: Int?

            internal init(stringValue: String) {
                self.stringValue = stringValue
                self.intValue = Int(stringValue)
            }

            internal init(intValue: Int) {
                self.stringValue = "\(intValue)"
                self.intValue = intValue
            }
        }

        init(from decoder: any Decoder) throws {
            // handle the embedded Quantization
            self.quantization = try Quantization(from: decoder)

            // and the interleaved per-layer values
            var perLayerQuantization = [String: QuantizationOption]()
            let container = try decoder.container(keyedBy: _DictionaryCodingKey.self)
            for key in container.allKeys {
                switch key.stringValue {
                case Quantization.CodingKeys.groupSize.rawValue: continue
                case Quantization.CodingKeys.bits.rawValue: continue

                default:
                    if let f = try? container.decode(Bool.self, forKey: key) {
                        if !f {
                            perLayerQuantization[key.stringValue] = .skip
                        }
                    } else {
                        perLayerQuantization[key.stringValue] = .quantize(
                            try container.decode(Quantization.self, forKey: key))
                    }
                }
            }
            self.perLayerQuantization = PerLayerQuantization(
                quantization: quantization, perLayerQuantization: perLayerQuantization)
        }

        func encode(to encoder: any Encoder) throws {
            try quantization.encode(to: encoder)

            var container = encoder.container(keyedBy: _DictionaryCodingKey.self)
            for (key, value) in perLayerQuantization.perLayerQuantization {
                switch value {
                case .skip:
                    try container.encode(false, forKey: .init(stringValue: key))
                case .quantize(let q):
                    try container.encode(q, forKey: .init(stringValue: key))
                }
            }
        }
    }

    var quantizationContainer: QuantizationContainer?

    @available(*, deprecated, message: "Please use perLayerQuantization instead")
    public var quantization: Quantization? {
        quantizationContainer?.quantization
    }

    public var perLayerQuantization: PerLayerQuantization? {
        quantizationContainer?.perLayerQuantization
    }

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case quantizationContainer = "quantization"
    }
}
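To make the decoding behavior concrete, here is a minimal sketch (not part of the commit) that runs the mixed-type decoder above. The model type and layer names are placeholder values; only `BaseConfiguration` and Foundation are assumed:

import Foundation

// Placeholder config exercising all three key shapes: Int (default
// group_size/bits), Quantization (per-layer override), Bool (skip).
let json = """
    {
        "model_type": "my_model",
        "quantization": {
            "group_size": 64,
            "bits": 4,
            "model.embed_tokens": { "group_size": 32, "bits": 4 },
            "model.layers.0.self_attn.q_norm": false
        }
    }
    """.data(using: .utf8)!

let config = try! JSONDecoder().decode(BaseConfiguration.self, from: json)

if let perLayer = config.perLayerQuantization {
    // explicit override -> Optional((32, 4))
    print(perLayer.quantization(layer: "model.embed_tokens")?.asTuple as Any)
    // `false` in the config -> .skip -> nil
    print(perLayer.quantization(layer: "model.layers.0.self_attn.q_norm") as Any)
    // any other layer falls back to the top-level default -> Optional((64, 4))
    print(perLayer.quantization(layer: "model.layers.1.mlp.down_proj")?.asTuple as Any)
}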

Libraries/Embedders/Configuration.swift

Lines changed: 0 additions & 26 deletions
@@ -110,29 +110,3 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
         try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
     }
 }
-
-public struct BaseConfiguration: Codable, Sendable {
-    public let modelType: ModelType
-
-    public struct Quantization: Codable, Sendable {
-        public init(groupSize: Int, bits: Int) {
-            self.groupSize = groupSize
-            self.bits = bits
-        }
-
-        let groupSize: Int
-        let bits: Int
-
-        enum CodingKeys: String, CodingKey {
-            case groupSize = "group_size"
-            case bits = "bits"
-        }
-    }
-
-    public var quantization: Quantization?
-
-    enum CodingKeys: String, CodingKey {
-        case modelType = "model_type"
-        case quantization
-    }
-}

Libraries/Embedders/Load.swift

Lines changed: 35 additions & 1 deletion
@@ -62,7 +62,8 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     let baseConfig = try JSONDecoder().decode(
         BaseConfiguration.self, from: Data(contentsOf: configurationURL))
 
-    let model = try baseConfig.modelType.createModel(configuration: configurationURL)
+    let modelType = ModelType(rawValue: baseConfig.modelType)
+    let model = try modelType.createModel(configuration: configurationURL)
 
     // load the weights
     var weights = [String: MLXArray]()
@@ -81,6 +82,16 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     weights = model.sanitize(weights: weights)
 
     // quantize if needed
+    if let perLayerQuantization = baseConfig.perLayerQuantization {
+        quantize(model: model) { path, module in
+            if weights["\(path).scales"] != nil {
+                return perLayerQuantization.quantization(layer: path)?.asTuple
+            } else {
+                return nil
+            }
+        }
+    }
+
     if let quantization = baseConfig.quantization {
         quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
             path, module in
@@ -108,3 +119,26 @@ public func loadModelContainer(
     return try await ModelContainer(
         hub: hub, modelDirectory: modelDirectory, configuration: configuration)
 }
+
+// TODO remove once mlx-swift update is adopted
+func quantize(
+    model: Module,
+    filter: (String, Module) -> (groupSize: Int, bits: Int)?,
+    apply: (Module, Int, Int) -> Module? = quantizeSingle(layer:groupSize:bits:)
+) {
+    let updates =
+        model
+        .leafModules()
+        .flattened()
+        .compactMap { (path, m) -> (String, Module)? in
+            if let (groupSize, bits) = filter(path, m) {
+                if let quantized = apply(m, groupSize, bits) {
+                    return (path, quantized)
+                }
+            }
+
+            return nil
+        }
+
+    model.update(modules: ModuleChildren.unflattened(updates))
+}
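The generic quantize helper walks the model's leaf modules and lets the filter closure decide, per path, whether and how to quantize -- returning an optional (groupSize, bits) is what allows mixed group sizes, bit widths, and skipped layers in a single pass. For comparison, the old uniform behavior fits the same API; a sketch, not from the commit, assuming model and weights as in loadSynchronous:

// Uniform 4-bit, group-size-64 quantization expressed via the new
// filter-based helper: quantize exactly the layers whose checkpoint
// contains ".scales" arrays (i.e. those saved in quantized form).
quantize(model: model) { path, module in
    weights["\(path).scales"] != nil ? (groupSize: 64, bits: 4) : nil
}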

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 2 additions & 1 deletion
@@ -325,7 +325,8 @@ public class LLMModelFactory: ModelFactory {
 
         // apply the weights to the bare model
         try loadWeights(
-            modelDirectory: modelDirectory, model: model, quantization: baseConfig.quantization)
+            modelDirectory: modelDirectory, model: model,
+            perLayerQuantization: baseConfig.perLayerQuantization)
 
         let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
 
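Switching the parameter loses nothing relative to the old call: PerLayerQuantization bundles the top-level default together with the per-layer map. A sketch, not from the commit, of what a callee like loadWeights can recover from the single value:

// Both pieces are available from baseConfig.perLayerQuantization.
if let perLayer = baseConfig.perLayerQuantization {
    let defaultQuant = perLayer.quantization       // top-level group_size / bits
    let overrides = perLayer.perLayerQuantization  // layer name -> skip / quantize
    print(defaultQuant as Any, overrides.count)
}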
