Commit 43d2e98

Improve adapters usage (LoRA, DoRA) (#326)
* Improve adapters usage
* Use quantized matmul
* Improve adapters factory
1 parent f7da396 commit 43d2e98
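
At a high level, this commit replaces the old `LoRAConvertToLinear` protocol and its `toLinear(deQuantize:)` method with a common `LoRALayer` protocol whose `fused()` method folds adapter weights back into the base layer, adds DoRA layers (`DoRALinear` / `QDoRALinear`) alongside the existing LoRA ones, threads a `scale` parameter through the factory methods, and computes the quantized base projection with `quantizedMatmul`. The snippet below is a minimal usage sketch, not part of the commit: only `DoRALinear.from(linear:rank:scale:)` and `fused()` come from this diff; the import target, the 512-unit `Linear`, and the training step are hypothetical placeholders.

import MLX
import MLXLMCommon   // assumed target for the adapter layers; adjust to the actual module
import MLXNN

// Wrap a base projection in a DoRA adapter (the factory falls back to
// QDoRALinear when given a QuantizedLinear, per the code in this commit).
let base = Linear(512, 512)
let adapter = DoRALinear.from(linear: base, rank: 8, scale: 20.0)

// ... train only the adapter parameters ("lora_a", "lora_b", "m"); the base
// weight and bias are frozen in the adapter's init ...

// Fold the trained adapter back into a plain layer for inference.
let fusedLayer = adapter.fused()   // Linear (or QuantizedLinear for the quantized variant)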

File tree

9 files changed, +762 -61 lines changed

Libraries/MLXLLM/LoraTrain.swift

Lines changed: 2 additions & 2 deletions
@@ -206,8 +206,8 @@ public enum LoRATrain {
            let children = layer.children()
            for key in keys {
                if let item = children[key], case .value(let child) = item {
-                   if let lora = child as? LoRAConvertToLinear {
-                       update[key] = .value(lora.toLinear(deQuantize: deQuantize))
+                   if let lora = child as? LoRALayer {
+                       update[key] = .value(lora.fused())
                    }
                }
            }
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
//
//  DoRA+Layers.swift
//  mlx-libraries
//
//  Created by Ivan Petrukha on 02.06.2025.
//

import Foundation
import MLX
import MLXLinalg
import MLXNN
import MLXRandom

/// Performs the forward pass for a DoRA linear layer.
private func forward(
    x: MLXArray, y: MLXArray,
    weight: MLXArray, bias: MLXArray?,
    loraA: MLXArray, loraB: MLXArray,
    scale: Float, magnitude: MLXArray
) -> MLXArray {
    let z = matmul(matmul(x, loraA), loraB)
    var out = y + (scale * z).asType(x.dtype)

    let adapted = weight + matmul(scale * loraB.T, loraA.T)
    let denom = norm(adapted, axis: 1)
    out *= (magnitude / denom).asType(x.dtype)

    return if let bias {
        out + bias
    } else {
        out
    }
}

/// Fuses the base weights with the DoRA parameters.
private func fuse(
    weight: MLXArray,
    loraA: MLXArray, loraB: MLXArray,
    scale: Float, magnitude: MLXArray
) -> MLXArray {
    let loraA = loraA.T.asType(weight.dtype)
    let loraB = (scale * loraB.T).asType(weight.dtype)

    var adapted = weight + matmul(loraB, loraA)
    let denom = norm(adapted, axis: 1)
    adapted *= (magnitude / denom).reshaped([-1, 1])

    return adapted
}

/// Filters out DoRA-specific parameters from a list of module keys.
private func filterFreezeKeys(from module: Module, keys: [String]?) -> [String] {
    return
        (keys
        ?? module.filterMap(filter: type(of: module).filterLocalParameters)
            .flattened()
            .map { $0.0 })
        .filter { !["lora_a", "lora_b", "m"].contains($0) }
}

/// Implementation of DoRA `Linear` replacement layer.
///
/// This layer implements DoRA (Weight-Decomposed Low-Rank Adaptation) for `Linear` layers.
///
/// ``QDoRALinear`` is the equivalent class for `QuantizedLinear`.
public class DoRALinear: Linear, LoRALayer {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray
    @ParameterInfo(key: "m") var magnitude: MLXArray

    required public init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
        let (outputDimensions, inputDimensions) = linear.shape
        let loraScale = 1 / sqrt(Float(inputDimensions))

        self.scale = scale
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
        self._magnitude.wrappedValue = MLXLinalg.norm(linear.weight, axis: 1)

        super.init(weight: linear.weight, bias: linear.bias)

        freeze()
    }

    public static func from(linear: Linear, rank: Int = 8, scale: Float = 20.0) -> LoRALayer {
        if let linear = linear as? QuantizedLinear {
            QDoRALinear(linear: linear, rank: rank, scale: scale)
        } else {
            DoRALinear(linear: linear, rank: rank, scale: scale)
        }
    }

    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        let keys = filterFreezeKeys(from: self, keys: keys)
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    public func fused() -> Module {
        Linear(
            weight: fuse(
                weight: weight, loraA: loraA, loraB: loraB, scale: scale, magnitude: magnitude),
            bias: bias
        )
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = matmul(x, weight.T)
        return forward(
            x: x, y: y,
            weight: weight, bias: bias,
            loraA: loraA, loraB: loraB,
            scale: scale, magnitude: magnitude
        )
    }
}

/// Implementation of DoRA `QuantizedLinear` replacement layer.
///
/// See ``DoRALinear`` (equivalent class for `Linear` layers) for more information.
///
/// ### See Also
/// - ``DoRALinear``
public class QDoRALinear: QuantizedLinear, LoRALayer {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray
    @ParameterInfo(key: "m") var magnitude: MLXArray

    required public init(linear: QuantizedLinear, rank: Int = 8, scale: Float = 20.0) {
        let (outputDimensions, inputDimensions) = linear.shape
        let loraScale = 1 / sqrt(Float(inputDimensions))

        self.scale = scale
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
        self._magnitude.wrappedValue = MLXLinalg.norm(linear.dequantizedWeight, axis: 1)

        super.init(
            weight: linear.weight, bias: linear.bias,
            scales: linear.scales, biases: linear.biases,
            groupSize: linear.groupSize, bits: linear.bits
        )

        freeze()
    }

    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        let keys = filterFreezeKeys(from: self, keys: keys)
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    public func fused() -> Module {
        QuantizedLinear(
            weight: fuse(
                weight: dequantizedWeight, loraA: loraA, loraB: loraB, scale: scale,
                magnitude: magnitude),
            bias: bias, groupSize: groupSize, bits: bits
        )
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = quantizedMatmul(
            x, weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)
        return forward(
            x: x, y: y,
            weight: dequantizedWeight, bias: bias,
            loraA: loraA, loraB: loraB,
            scale: scale, magnitude: magnitude
        )
    }
}
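
In both new layers, forward scales the combined output (the base projection plus scale * x * A * B) by the learned magnitude m divided by the per-row norm of the adapted weight W + scale * B^T * A^T, and fuse bakes that same per-row rescaling into the weight itself, so a fused layer should reproduce the adapter's output. A small consistency-check sketch (assumed usage, not part of the commit; the import target is a placeholder):

import MLX
import MLXLMCommon   // assumed target for DoRALinear
import MLXNN
import MLXRandom

// With a freshly initialized adapter, lora_b is zero and m equals the row
// norms of the base weight, so the adapter and its fused() layer must agree;
// after training they should still match up to floating-point tolerance.
let base = Linear(64, 64)
let dora = DoRALinear(linear: base, rank: 8, scale: 20.0)

let x = MLXRandom.normal([4, 64])
let adapted = dora(x)                        // forward pass through the adapter
let fused = (dora.fused() as! Linear)(x)     // forward pass through the fused Linear

print(allClose(adapted, fused, atol: 1e-5))  // expected: true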

Libraries/MLXLMCommon/Lora.swift renamed to Libraries/MLXLMCommon/Adapters/LoRA/LoRA+Layers.swift

Lines changed: 20 additions & 59 deletions
@@ -4,49 +4,7 @@ import Foundation
 import MLX
 import MLXNN
 import MLXOptimizers
-import Tokenizers
-
-/// Layers to apply LoRA adapters to.
-///
-/// This is the value returned by ``LoRAModel/loraLinearLayers()``.
-public typealias LoRALinearLayers = [(Module, [String])]
-
-public protocol LoRAModel {
-    /// Return the layers and keys to apply LoRA adapters to.
-    ///
-    /// For example this might apply the adapters to the `q` an `v` projections in the
-    /// Attention layers:
-    ///
-    /// ```swift
-    /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
-    /// ```
-    ///
-    /// It is not required that a model implement this protocol to have LoRA adapters applied, but
-    /// the command line driver example uses this to produce the ``LoRALinearLayers``.
-    ///
-    /// ### See Also
-    /// - ``LoRATrain/convert(model:layers:)``
-    func loraLinearLayers() -> LoRALinearLayers
-
-    /// Return a suffix of the layers and keys to apply LoRA adapters to.
-    ///
-    /// See ``loraLinearLayers()``
-    func loraLinearLayers(_ count: Int) -> LoRALinearLayers
-}
-
-extension LoRAModel {
-    public func loraLinearLayers(_ count: Int) -> LoRALinearLayers {
-        loraLinearLayers().suffix(count)
-    }
-}
-
-/// Protocol for LoRA implementations that provides a method for converting back to a `Linear`
-/// (or subtype).
-///
-/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)``
-public protocol LoRAConvertToLinear {
-    func toLinear(deQuantize: Bool) -> Linear
-}
+import MLXRandom
 
 /// Implementation of LoRA `Linear` replacement layer.
 ///
@@ -67,7 +25,7 @@ public protocol LoRAConvertToLinear {
 /// - ``QLoRALinear``
 /// - ``LoRATrain/convert(model:layers:)``
 /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
-public class LoRALinear: Linear, LoRAConvertToLinear {
+public class LoRALinear: Linear, LoRALayer {
 
     let scale: Float
 
@@ -113,12 +71,13 @@ public class LoRALinear: Linear, LoRAConvertToLinear {
     /// ### See Also
     /// - ``LoRATrain/convert(model:layers:)``
     /// - ``QLoRALinear/from(linear:rank:)``
-    public static func from(linear: Linear, rank: Int = 8) -> Linear {
+    public static func from(linear: Linear, rank: Int = 8, scale: Float = 20.0) -> LoRALayer {
         if let linear = linear as? QuantizedLinear {
-            return QLoRALinear.from(linear: linear, rank: rank)
+            return QLoRALinear.from(linear: linear, rank: rank, scale: scale)
         }
         let (outputDimensions, inputDimensions) = linear.shape
-        return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
+        return LoRALinear(
+            inputDimensions, outputDimensions, rank: rank, scale: scale, linear: linear)
     }
 
     /// Convert back into a fused `Linear` layer.
@@ -129,7 +88,7 @@ public class LoRALinear: Linear, LoRAConvertToLinear {
     /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
     /// - ``LoRAConvertToLinear``
     /// - ``QLoRALinear/toLinear(deQuantize:)``
-    public func toLinear(deQuantize: Bool = false) -> Linear {
+    public func fused() -> Module {
         let dtype = weight.dtype
         let loraB = (scale * loraB.T).asType(dtype)
         let loraA = loraA.T.asType(dtype)
@@ -146,7 +105,7 @@ public class LoRALinear: Linear, LoRAConvertToLinear {
 /// Implementation of LoRA `QuantizedLinear` replacement layer.
 ///
 /// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information.
-public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {
+public class QLoRALinear: QuantizedLinear, LoRALayer {
 
     let scale: Float
 
@@ -196,9 +155,12 @@ public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {
     /// ### See Also
     /// - ``LoRATrain/convert(model:layers:)``
    /// - ``LoRALinear/from(linear:rank:)``
-    public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear {
+    public static func from(linear: QuantizedLinear, rank: Int = 8, scale: Float = 20.0)
+        -> LoRALayer
+    {
         let (outputDimensions, inputDimensions) = linear.shape
-        return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
+        return QLoRALinear(
+            inputDimensions, outputDimensions, rank: rank, scale: scale, linear: linear)
     }
 
     /// Convert back into a fused `QuantizedLinear` layer.
@@ -207,17 +169,16 @@ public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {
     ///
     /// ### See Also
     /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
-    public func toLinear(deQuantize: Bool = false) -> Linear {
-        // convert back into full weights
-        let weight = dequantized(
-            weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)
-
+    public func fused() -> Module {
+        let weight = dequantizedWeight
         let loraB = (scale * loraB.T).asType(.float16)
         let loraA = loraA.T.asType(.float16)
-
-        // convert back into quantized
         return QuantizedLinear(
-            weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits)
+            weight: weight + matmul(loraB, loraA),
+            bias: bias,
+            groupSize: groupSize,
+            bits: bits
+        )
     }
 
     public override func callAsFunction(_ x: MLXArray) -> MLXArray {
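
On the "Use quantized matmul" part of the commit message: in the quantized layers the base projection is now computed with `quantizedMatmul` on the packed weights (see `QDoRALinear.callAsFunction` above), while `dequantizedWeight` is reserved for the magnitude and fusion paths. The sketch below is an illustration rather than code from this commit; it shows two numerically equivalent ways of computing that projection using only calls that appear in the diff.

import MLX
import MLXNN
import MLXRandom

// Build a quantized layer from a plain Linear (this initializer is the one
// used by fused() in the diff above).
let base = Linear(128, 128)
let quantized = QuantizedLinear(
    weight: base.weight, bias: base.bias, groupSize: 64, bits: 4)
let x = MLXRandom.normal([2, 128])

// Dequantize first, then multiply...
let viaDequantize = matmul(
    x,
    dequantized(
        quantized.weight, scales: quantized.scales, biases: quantized.biases,
        groupSize: quantized.groupSize, bits: quantized.bits
    ).T)

// ...or multiply directly against the packed weights, as the new
// callAsFunction implementations do.
let viaQuantizedMatmul = quantizedMatmul(
    x, quantized.weight, scales: quantized.scales, biases: quantized.biases,
    groupSize: quantized.groupSize, bits: quantized.bits)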
