@@ -51,39 +51,83 @@ func computeBaseFrequency(
51
51
52
52
/// Rotary position embedding layer that optionally applies "llama3"-style
/// frequency scaling.
///
/// When `ropeType` is `"llama3"` and `ropeScaling` supplies a `factor`, a
/// per-dimension frequency table is computed once at initialization and handed
/// to `MLXFast.RoPE` through `freqs`. In every other case `freqs` stays `nil`
/// and the scalar `base` is passed instead.
private class DynamicNTKScalingRoPE: Module {
    /// Number of feature dimensions the rotation is applied to.
    let dims: Int
    /// Maximum position count; falls back to 2048 when the caller passes `nil`.
    /// Stored for callers; not read by any method of this class.
    let maxPositionEmbeddings: Int
    /// Forwarded unchanged to `MLXFast.RoPE(traditional:)`.
    let traditional: Bool
    /// Scalar rotary base. Set to `nil` once `freqs` has been computed, so the
    /// kernel receives the frequency table instead of the base.
    var base: Float?
    /// Forwarded unchanged to `MLXFast.RoPE(scale:)`.
    let scale: Float
    /// Scaling scheme name; only `"llama3"` triggers the custom frequency table.
    let ropeType: String
    /// Raw scaling configuration (`factor`, `low_freq_factor`,
    /// `high_freq_factor`, `original_max_position_embeddings`).
    let ropeScaling: [String: StringOrNumber]?
    /// Precomputed per-dimension frequencies, or `nil` when `base` is used.
    var freqs: MLXArray?

    /// Creates the layer and precomputes the frequency table when applicable.
    ///
    /// - Parameters:
    ///   - dims: Number of feature dimensions to rotate.
    ///   - maxPositionEmbeddings: Maximum position count; `nil` means 2048.
    ///   - traditional: Forwarded to `MLXFast.RoPE`.
    ///   - base: Scalar rotary base used when no frequency table is built.
    ///   - scale: Forwarded to `MLXFast.RoPE`.
    ///   - ropeType: Scaling scheme name.
    ///   - ropeScaling: Optional scaling configuration dictionary.
    init(
        dims: Int,
        maxPositionEmbeddings: Int?,
        traditional: Bool = false,
        base: Float = 10000,
        scale: Float = 1.0,
        ropeType: String = "default",
        ropeScaling: [String: StringOrNumber]? = nil
    ) {
        self.dims = dims
        self.maxPositionEmbeddings = maxPositionEmbeddings ?? 2048
        self.traditional = traditional
        self.base = base
        self.scale = scale
        self.ropeType = ropeType
        self.ropeScaling = ropeScaling
        super.init()
        computeFreqs()
    }

    /// Builds `freqs` for the `"llama3"` scheme, or leaves it `nil`.
    ///
    /// The table is left `nil` (and `base` untouched) when `ropeType` is not
    /// `"llama3"`, when `ropeScaling` is missing, when `factor` is absent, or
    /// when any of the consulted entries is not a `.float` value.
    private func computeFreqs() {
        guard ropeType == "llama3" else {
            freqs = nil
            return
        }

        // Only `factor` is mandatory; the remaining entries have defaults.
        guard let ropeScaling = ropeScaling,
            case .float(let factor) = ropeScaling["factor"],
            case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
            case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
            case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
                ?? .float(8192),
            let base
        else {
            freqs = nil
            return
        }

        // Wavelength thresholds separating the low / medium / high bands.
        let lowFreqWavelen = oldContextLen / lowFreqFactor
        let highFreqWavelen = oldContextLen / highFreqFactor

        // base^(i / dims) for i = 0, 2, 4, ... < dims.
        let indices = MLXArray(stride(from: 0, to: dims, by: 2))
        var frequencies = MLX.pow(base, indices / Float(dims))
        let wavelens = 2 * Float.pi * frequencies

        // Entries whose wavelength exceeds the low-frequency threshold are
        // multiplied by `factor`.
        frequencies = MLX.where(
            wavelens .> MLXArray(lowFreqWavelen), frequencies * factor, frequencies)

        // Entries strictly between the two thresholds get a smooth blend.
        let isMediumFreq = MLX.logicalAnd(
            wavelens .> MLXArray(highFreqWavelen),
            wavelens .< MLXArray(lowFreqWavelen)
        )
        let smoothFactors =
            (oldContextLen / wavelens - lowFreqFactor) / (highFreqFactor - lowFreqFactor)
        let smoothFreqs = frequencies / ((1 - smoothFactors) / factor + smoothFactors)

        freqs = MLX.where(isMediumFreq, smoothFreqs, frequencies)
        // From here on the kernel is driven by `freqs`, not the scalar base.
        self.base = nil
    }

    /// Applies the rotary embedding to `x`.
    ///
    /// - Parameters:
    ///   - x: Input array, forwarded to `MLXFast.RoPE`.
    ///   - offset: Position offset, forwarded to `MLXFast.RoPE`.
    /// - Returns: The result of `MLXFast.RoPE`.
    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
        MLXFast.RoPE(
            x,
            dimensions: dims,
            traditional: traditional,
            base: base,
            scale: scale,
            offset: offset,
            freqs: freqs
        )
    }
}
89
133
0 commit comments