Skip to content

Commit 2dbe65d

Browse files
Fix DynamicNTKScalingRoPE (#154)
* Fix DynamicNTKScalingRoPE * Fix base
1 parent 7baf9bc commit 2dbe65d

File tree

1 file changed

+62
-18
lines changed

1 file changed

+62
-18
lines changed

Libraries/LLM/Models/Llama.swift

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,39 +51,83 @@ func computeBaseFrequency(
5151

5252
/// Rotary position embedding module that forwards to `MLXFast.RoPE`.
///
/// In the post-commit version shown by this diff, per-dimension frequencies are
/// precomputed once in `init` (via `computeFreqs()`) when `ropeType` is "llama3";
/// for every other rope type `freqs` stays nil and the plain `base` is passed through.
// NOTE(review): on the added side, `maxPositionEmbeddings` is stored but not read
// anywhere in the visible class — confirm whether it is still needed.
private class DynamicNTKScalingRoPE: Module {
5353
let dims: Int
54-
let maxPositionEmbeddings: Int?
54+
let maxPositionEmbeddings: Int
5555
let traditional: Bool
56-
let base: Float
57-
var scale: Float
56+
// Optional and mutable because computeFreqs() sets it to nil after it has
// produced explicit frequencies.
var base: Float?
57+
let scale: Float
5858
let ropeType: String
5959
let ropeScaling: [String: StringOrNumber]?
60+
// Precomputed frequencies handed to MLXFast.RoPE; remains nil unless ropeType is
// "llama3" and the scaling parameters pass the guard in computeFreqs().
var freqs: MLXArray?
6061

6162
init(
62-
dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
63-
base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
63+
dims: Int,
64+
maxPositionEmbeddings: Int?,
65+
traditional: Bool = false,
66+
base: Float = 10000,
67+
scale: Float = 1.0,
68+
ropeType: String = "default",
6469
ropeScaling: [String: StringOrNumber]? = nil
6570
) {
6671
self.dims = dims
67-
self.maxPositionEmbeddings = maxPositionEmbeddings
72+
// A nil argument falls back to 2048.
self.maxPositionEmbeddings = maxPositionEmbeddings ?? 2048
6873
self.traditional = traditional
69-
self.base = computeBaseFrequency(
70-
base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
74+
// The base is now stored as given; it is no longer run through computeBaseFrequency.
self.base = base
7175
self.scale = scale
7276
self.ropeType = ropeType
7377
self.ropeScaling = ropeScaling
78+
super.init()
79+
// Runs after super.init() and after all stored properties above are assigned;
// it reads ropeType, ropeScaling, dims and base.
computeFreqs()
7480
}
7581

76-
func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
77-
let seqLen = x.dim(1) + offset
78-
var base = self.base
79-
if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
80-
let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
81-
let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
82-
let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
83-
base *= adjustedScale
82+
/// Populates `freqs` for the "llama3" rope type and clears `base` on success.
///
/// For any other `ropeType`, or when the scaling parameters fail the guard below,
/// `freqs` is set to nil and `base` is left untouched.
private func computeFreqs() {
83+
if ropeType != "llama3" {
84+
freqs = nil
85+
return
86+
}
87+
88+
// "factor" has no default and must be present. The other three keys fall back to
// 1.0, 4.0 and 8192 when absent. Any value that is not the `.float` case, a nil
// `ropeScaling`, or a nil `base` makes the guard fail.
guard let ropeScaling = ropeScaling,
89+
case .float(let factor) = ropeScaling["factor"],
90+
case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
91+
case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
92+
case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
93+
?? .float(8192),
94+
let base
95+
else {
96+
freqs = nil
97+
return
8498
}
85-
return MLXFast.RoPE(
86-
x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
99+
100+
// Wavelength thresholds that split the dimensions into three bands below.
let lowFreqWavelen = oldContextLen / lowFreqFactor
101+
let highFreqWavelen = oldContextLen / highFreqFactor
102+
103+
// One entry per even index 0, 2, ..., < dims.
let indices = MLXArray(stride(from: 0, to: dims, by: 2))
104+
// `frequencies` holds base^(i / dims) for each index i.
var frequencies = MLX.pow(base, indices / Float(dims))
105+
let wavelens = 2 * Float.pi * frequencies
106+
107+
// Entries whose wavelength exceeds lowFreqWavelen are multiplied by `factor`;
// all others are kept as they are.
frequencies = MLX.where(
108+
wavelens .> MLXArray(lowFreqWavelen), frequencies * factor, frequencies)
109+
// Band strictly between the two thresholds (both comparisons are exclusive).
let isMediumFreq = MLX.logicalAnd(
110+
wavelens .> MLXArray(highFreqWavelen),
111+
wavelens .< MLXArray(lowFreqWavelen)
112+
)
113+
// Interpolation weight computed for every entry; only the medium band uses it.
// If highFreqFactor == lowFreqFactor the denominator is zero, but then the two
// thresholds coincide and isMediumFreq selects nothing.
let smoothFactors =
114+
(oldContextLen / wavelens - lowFreqFactor) / (highFreqFactor - lowFreqFactor)
115+
let smoothFreqs = frequencies / ((1 - smoothFactors) / factor + smoothFactors)
116+
117+
// Medium-band entries take the smoothed value; the rest keep `frequencies`.
freqs = MLX.where(isMediumFreq, smoothFreqs, frequencies)
118+
// With explicit freqs available, base is cleared so callAsFunction passes nil for it.
// NOTE(review): presumably MLXFast.RoPE expects only one of base/freqs — confirm.
self.base = nil
119+
}
120+
121+
/// Calls `MLXFast.RoPE` on `x`, forwarding `offset` together with the stored
/// `dims`, `traditional`, `base`, `scale` and `freqs`.
func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
122+
MLXFast.RoPE(
123+
x,
124+
dimensions: dims,
125+
traditional: traditional,
126+
base: base,
127+
scale: scale,
128+
offset: offset,
129+
freqs: freqs
130+
)
87131
}
88132
}
89133

0 commit comments

Comments
 (0)