@@ -1034,23 +1034,6 @@ private class Gemma3nDecoderLayer: Module {
     }
 }
 
-private class Gemma3nTextScaledWordEmbedding: Module, UnaryLayer {
-    @ModuleInfo var weight: MLXArray
-    let embedScale: Float
-
-    init(numEmbeddings: Int, embeddingDim: Int, embedScale: Float = 1.0) {
-        self.embedScale = embedScale
-        self._weight.wrappedValue = MLXRandom.normal([numEmbeddings, embeddingDim])
-        super.init()
-    }
-
-    func callAsFunction(_ x: MLXArray) -> MLXArray {
-        let indices = x.asType(.int32)
-        let embeddings = take(weight, indices, axis: 0)
-        return embeddings * MLXArray(embedScale, dtype: .float32).asType(weight.dtype)
-    }
-}
-
 private class Gemma3Model: Module {
     let config: TextConfig
     let hiddenSize: Int
@@ -1059,11 +1042,12 @@ private class Gemma3Model: Module {
     let numHiddenLayers: Int
     private let _perLayerProjectionScale: MLXArray
     private let _perLayerInputScale: MLXArray
+    private let _embedTokensScale: Float
+    private let _embedTokensPerLayerScale: Float
 
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Gemma3nTextScaledWordEmbedding
-    @ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer]  // This is correct!
-    @ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer:
-        Gemma3nTextScaledWordEmbedding
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    @ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer]
+    @ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer: Embedding
     @ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Linear
     @ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm:
         Gemma3nRMSNormWithScale
@@ -1084,21 +1068,21 @@ private class Gemma3Model: Module {
 
         assert(vocabSize > 0)
 
-        self._embedTokens.wrappedValue = Gemma3nTextScaledWordEmbedding(
-            numEmbeddings: config.vocabSize,
-            embeddingDim: config.hiddenSize,
-            embedScale: pow(Float(config.hiddenSize), 0.5)
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: config.vocabSize,
+            dimensions: config.hiddenSize
         )
+        self._embedTokensScale = pow(Float(config.hiddenSize), 0.5)
 
         self._layers.wrappedValue = (0 ..< config.numHiddenLayers).map { layerIdx in
             Gemma3nDecoderLayer(config: config, layerIdx: layerIdx)
         }
 
-        self._embedTokensPerLayer.wrappedValue = Gemma3nTextScaledWordEmbedding(
-            numEmbeddings: config.vocabSizePerLayerInput,
-            embeddingDim: config.numHiddenLayers * config.hiddenSizePerLayerInput,
-            embedScale: pow(Float(config.hiddenSizePerLayerInput), 0.5)
+        self._embedTokensPerLayer.wrappedValue = Embedding(
+            embeddingCount: config.vocabSizePerLayerInput,
+            dimensions: config.numHiddenLayers * config.hiddenSizePerLayerInput
         )
+        self._embedTokensPerLayerScale = pow(Float(config.hiddenSizePerLayerInput), 0.5)
 
         self._perLayerModelProjection.wrappedValue = Linear(
             config.hiddenSize,
@@ -1150,6 +1134,7 @@ private class Gemma3Model: Module {
             h = inputsEmbeds
         } else if let inputs {
            h = embedTokens(inputs)
+            h = (h * MLXArray(_embedTokensScale, dtype: .float32)).asType(h.dtype)
         } else {
             fatalError("Either inputs or inputsEmbeds must be provided")
         }
@@ -1253,7 +1238,10 @@ private class Gemma3Model: Module {
             inputIds .< vocabSizePerLayerInput
         )
         let tokens = MLX.where(perLayerInputsMask, inputIds, MLXArray.zeros(like: inputIds))
-        let result = embedTokensPerLayer(tokens).reshaped(
+        var result = embedTokensPerLayer(tokens)
+        result = (result * MLXArray(_embedTokensPerLayerScale, dtype: .float32)).asType(
+            result.dtype)
+        result = result.reshaped(
             Array(inputIds.shape) + [config.numHiddenLayers, config.hiddenSizePerLayerInput]
         )
         return result
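
The change above replaces the custom Gemma3nTextScaledWordEmbedding with a stock Embedding and moves the sqrt(dim) scale to the call sites, applied in float32 and cast back to the activation dtype. A minimal standalone sketch of that pattern, not part of the commit, using the same MLX Swift APIs the diff relies on and hypothetical sizes:

import Foundation
import MLX
import MLXNN

let hiddenSize = 8  // hypothetical, for illustration only
let embed = Embedding(embeddingCount: 16, dimensions: hiddenSize)
let embedScale = pow(Float(hiddenSize), 0.5)  // sqrt(hiddenSize), as in the commit

let tokens = MLXArray([1, 2, 3])
var h = embed(tokens)
// Scale in float32, then cast back to the embedding dtype, matching the diff.
h = (h * MLXArray(embedScale, dtype: .float32)).asType(h.dtype)
print(h.shape)  // [3, 8]

Keeping the table in a plain Embedding means checkpoint weights load under the standard "embed_tokens" key with no custom module code, while the external scale preserves the original numerics.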