
Commit 59ca81f

Correctly scale text embeddings for quantized models (mlx-vlm #397)

1 parent: f83d531
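
The change removes the custom `Gemma3nTextScaledWordEmbedding` module in favor of the standard MLXNN `Embedding`, and applies the `sqrt(dim)` embedding scale in the model's forward pass instead of inside the layer. The apparent motivation: MLXNN's quantization machinery swaps standard `Embedding` layers for `QuantizedEmbedding`, and the custom class, which owned its own `weight` and baked the scale into the lookup, was opaque to that pass, so quantized checkpoints did not map onto the scaled path. A minimal sketch of the resulting pattern, using only APIs that appear in the diff below (`TextStub` and its parameters are hypothetical, not the real Gemma3n model):

    import Foundation
    import MLX
    import MLXNN

    // Sketch of the pattern this commit adopts (TextStub is hypothetical):
    // keep the token table a plain `Embedding`, so the quantization pass can
    // replace it, and apply the sqrt(dim) scale outside the layer.
    private class TextStub: Module {
        @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
        let embedTokensScale: Float

        init(vocabSize: Int, hiddenSize: Int) {
            self._embedTokens.wrappedValue = Embedding(
                embeddingCount: vocabSize, dimensions: hiddenSize)
            self.embedTokensScale = pow(Float(hiddenSize), 0.5)
            super.init()
        }

        func callAsFunction(_ inputs: MLXArray) -> MLXArray {
            var h = embedTokens(inputs)
            // Multiply in float32 for precision, then cast back to the
            // activation dtype, mirroring the removed class's embedScale.
            h = (h * MLXArray(embedTokensScale, dtype: .float32)).asType(h.dtype)
            return h
        }
    }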


Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 18 additions & 30 deletions
@@ -1034,23 +1034,6 @@ private class Gemma3nDecoderLayer: Module {
     }
 }
 
-private class Gemma3nTextScaledWordEmbedding: Module, UnaryLayer {
-    @ModuleInfo var weight: MLXArray
-    let embedScale: Float
-
-    init(numEmbeddings: Int, embeddingDim: Int, embedScale: Float = 1.0) {
-        self.embedScale = embedScale
-        self._weight.wrappedValue = MLXRandom.normal([numEmbeddings, embeddingDim])
-        super.init()
-    }
-
-    func callAsFunction(_ x: MLXArray) -> MLXArray {
-        let indices = x.asType(.int32)
-        let embeddings = take(weight, indices, axis: 0)
-        return embeddings * MLXArray(embedScale, dtype: .float32).asType(weight.dtype)
-    }
-}
-
 private class Gemma3Model: Module {
     let config: TextConfig
     let hiddenSize: Int
@@ -1059,11 +1042,12 @@ private class Gemma3Model: Module {
     let numHiddenLayers: Int
     private let _perLayerProjectionScale: MLXArray
     private let _perLayerInputScale: MLXArray
+    private let _embedTokensScale: Float
+    private let _embedTokensPerLayerScale: Float
 
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Gemma3nTextScaledWordEmbedding
-    @ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer] // This is correct!
-    @ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer:
-        Gemma3nTextScaledWordEmbedding
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    @ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer]
+    @ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer: Embedding
     @ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Linear
     @ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm:
         Gemma3nRMSNormWithScale
@@ -1084,21 +1068,21 @@ private class Gemma3Model: Module {
 
         assert(vocabSize > 0)
 
-        self._embedTokens.wrappedValue = Gemma3nTextScaledWordEmbedding(
-            numEmbeddings: config.vocabSize,
-            embeddingDim: config.hiddenSize,
-            embedScale: pow(Float(config.hiddenSize), 0.5)
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: config.vocabSize,
+            dimensions: config.hiddenSize,
         )
+        self._embedTokensScale = pow(Float(config.hiddenSize), 0.5)
 
         self._layers.wrappedValue = (0 ..< config.numHiddenLayers).map { layerIdx in
             Gemma3nDecoderLayer(config: config, layerIdx: layerIdx)
         }
 
-        self._embedTokensPerLayer.wrappedValue = Gemma3nTextScaledWordEmbedding(
-            numEmbeddings: config.vocabSizePerLayerInput,
-            embeddingDim: config.numHiddenLayers * config.hiddenSizePerLayerInput,
-            embedScale: pow(Float(config.hiddenSizePerLayerInput), 0.5)
+        self._embedTokensPerLayer.wrappedValue = Embedding(
+            embeddingCount: config.vocabSizePerLayerInput,
+            dimensions: config.numHiddenLayers * config.hiddenSizePerLayerInput,
         )
+        self._embedTokensPerLayerScale = pow(Float(config.hiddenSizePerLayerInput), 0.5)
 
         self._perLayerModelProjection.wrappedValue = Linear(
             config.hiddenSize,
@@ -1150,6 +1134,7 @@ private class Gemma3Model: Module {
             h = inputsEmbeds
         } else if let inputs {
             h = embedTokens(inputs)
+            h = (h * MLXArray(_embedTokensScale, dtype: .float32)).asType(h.dtype)
         } else {
             fatalError("Either inputs or inputsEmbeds must be provided")
         }
@@ -1253,7 +1238,10 @@ private class Gemma3Model: Module {
             inputIds .< vocabSizePerLayerInput
         )
         let tokens = MLX.where(perLayerInputsMask, inputIds, MLXArray.zeros(like: inputIds))
-        let result = embedTokensPerLayer(tokens).reshaped(
+        var result = embedTokensPerLayer(tokens)
+        result = (result * MLXArray(_embedTokensPerLayerScale, dtype: .float32)).asType(
+            result.dtype)
+        result = result.reshaped(
             Array(inputIds.shape) + [config.numHiddenLayers, config.hiddenSizePerLayerInput]
        )
         return result
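
For context on why the indirection matters: `QuantizedEmbedding` subclasses `Embedding` in MLXNN, so once `embedTokens` is declared as a plain `Embedding`, the quantization pass can substitute the quantized layer in place, and the external scale applies either way. A hedged usage sketch, assuming MLXNN's `quantize(model:groupSize:bits:)` helper and the hypothetical `TextStub` from the earlier sketch:

    // Quantize the stub in place; with a plain `Embedding` the pass can swap
    // in `QuantizedEmbedding` (a subclass), which the old custom class
    // prevented. `quantize(model:groupSize:bits:)` is assumed from MLXNN.
    let model = TextStub(vocabSize: 32_000, hiddenSize: 2048)
    quantize(model: model, groupSize: 64, bits: 4)

    // The forward pass is unchanged: quantized lookup, then the sqrt(2048) scale.
    let h = model(MLXArray([1, 2, 3]))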
