@@ -1146,18 +1146,18 @@ private class Gemma3Model: Module {
1146
1146
perLayerInputs: MLXArray ? = nil
1147
1147
) -> MLXArray {
1148
1148
var h : MLXArray
1149
- if let inputsEmbeds = inputsEmbeds {
1149
+ if let inputsEmbeds {
1150
1150
h = inputsEmbeds
1151
- } else if let inputs = inputs {
1151
+ } else if let inputs {
1152
1152
h = embedTokens ( inputs)
1153
1153
} else {
1154
1154
fatalError ( " Either inputs or inputsEmbeds must be provided " )
1155
1155
}
1156
1156
1157
1157
let perLayerInputsProcessed : MLXArray
1158
- if let perLayerInputs = perLayerInputs {
1158
+ if let perLayerInputs {
1159
1159
perLayerInputsProcessed = perLayerInputs
1160
- } else if let inputs = inputs {
1160
+ } else if let inputs {
1161
1161
perLayerInputsProcessed = getPerLayerInputs ( inputs)
1162
1162
} else {
1163
1163
fatalError ( " Cannot generate per layer inputs without input ids " )
@@ -1213,7 +1213,7 @@ private class Gemma3Model: Module {
1213
1213
== " global_attention "
1214
1214
1215
1215
let localMask : MLXFast . ScaledDotProductAttentionMaskMode
1216
- if let mask = mask {
1216
+ if let mask {
1217
1217
localMask = mask
1218
1218
} else if isGlobal {
1219
1219
localMask = fullMask
@@ -1437,9 +1437,9 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1437
1437
}
1438
1438
1439
1439
let embNorm : MLXArray
1440
- if let inputsEmbeds = inputsEmbeds {
1440
+ if let inputsEmbeds {
1441
1441
embNorm = softEmbeddingNorm ( inputsEmbeds)
1442
- } else if let inputIds = inputIds {
1442
+ } else if let inputIds {
1443
1443
let hardEmb = embedding ( inputIds - vocabOffset)
1444
1444
embNorm = hardEmbeddingNorm ( hardEmb)
1445
1445
} else {
@@ -1490,7 +1490,7 @@ private func gemma3nAttentionWithCacheUpdate(
1490
1490
// Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
1491
1491
let ( cachedKeys, cachedValues) : ( MLXArray , MLXArray )
1492
1492
1493
- if let cache = cache {
1493
+ if let cache {
1494
1494
( cachedKeys, cachedValues) = cache. update ( keys: keys, values: values)
1495
1495
} else {
1496
1496
( cachedKeys, cachedValues) = ( keys, values)
@@ -1667,7 +1667,6 @@ private func maskedScatter(
1667
1667
private func checkArrayShape( _ arr: MLXArray ) -> Bool {
1668
1668
let shape = arr. shape
1669
1669
guard shape. count == 4 else {
1670
- print ( " 🔍 checkArrayShape: Array has \( shape. count) dimensions, not 4 " )
1671
1670
return false
1672
1671
}
1673
1672
@@ -1792,7 +1791,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1792
1791
}
1793
1792
1794
1793
// Process audio features
1795
- if let inputFeatures = inputFeatures , let inputFeaturesMask = inputFeaturesMask {
1794
+ if let inputFeatures, let inputFeaturesMask = inputFeaturesMask {
1796
1795
let ( audioFeatures, audioMask) = getAudioFeatures ( inputFeatures, .! inputFeaturesMask)
1797
1796
let audioPaddingIds = MLXArray ( [ config. vocabSize - 1 ] ) . expandedDimensions ( axis: 0 )
1798
1797
let audioPaddingEmbs = embedAudio. callAsFunction ( audioPaddingIds, inputsEmbeds: nil )
@@ -1862,7 +1861,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1862
1861
) -> MLXArray {
1863
1862
let specialModalityMask : MLXArray
1864
1863
1865
- if let inputIds = inputIds {
1864
+ if let inputIds {
1866
1865
specialModalityMask = expandedDimensions ( inputIds .== tokenId, axis: - 1 )
1867
1866
} else {
1868
1867
// When inputIds is nil, create mask by comparing embeddings
@@ -1924,10 +1923,9 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1924
1923
1925
1924
// In class Gemma3n
1926
1925
public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
1927
- print ( " 🔍 Gemma3n.sanitize: Starting with \( weights. count) weights " )
1928
1926
var sanitizedWeights = [ String: MLXArray] ( )
1929
1927
1930
- // This function's ONLY job is to remove the "model." prefix from keys.
1928
+ // Remove the "model." prefix from keys.
1931
1929
for (k, v) in weights {
1932
1930
if k. hasPrefix ( " model. " ) {
1933
1931
let newKey = k. split ( separator: " . " ) . dropFirst ( ) . joined ( separator: " . " )
@@ -1937,13 +1935,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1937
1935
}
1938
1936
}
1939
1937
1940
- print ( " 🔍 Gemma3n.sanitize: After prefix removal, have \( sanitizedWeights. count) weights " )
1941
1938
return sanitizedWeights
1942
1939
}
1943
1940
1944
1941
public static func fromPretrained( pathOrHfRepo: String ) throws -> Gemma3n {
1945
1942
let path = URL ( fileURLWithPath: pathOrHfRepo)
1946
- print ( " 🔍 Gemma3n.fromPretrained: Loading from \( pathOrHfRepo) " )
1947
1943
1948
1944
let configPath = path. appendingPathComponent ( " config.json " )
1949
1945
let configData = try Data ( contentsOf: configPath)
@@ -1968,30 +1964,25 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1968
1964
let fileWeights = try loadArrays ( url: path. appendingPathComponent ( weightFile) )
1969
1965
weights. merge ( fileWeights) { _, new in new }
1970
1966
}
1971
- print ( " 🔍 Gemma3n.fromPretrained: Total weights loaded: \( weights. count) " )
1972
1967
1973
- // Step 1: Main sanitization (remove "model." prefix)
1968
+ // Main sanitization (remove "model." prefix)
1974
1969
var sanitizedWeights = model. sanitize ( weights: weights)
1975
1970
1976
- // Step 2: Vision model sanitization (transpose conv weights)
1971
+ // Vision model sanitization (transpose conv weights)
1977
1972
sanitizedWeights = Gemma3nVisionModel . sanitizeWeights ( sanitizedWeights)
1978
1973
1979
- // Step 3: Audio model sanitization (transpose conv weights) - THIS WAS MISSING
1974
+ // Audio model sanitization (transpose conv weights)
1980
1975
sanitizedWeights = model. audioTower. sanitize ( weights: sanitizedWeights)
1981
1976
1982
- // Step 4: Handle tied lm_head weights
1977
+ // Handle tied lm_head weights
1983
1978
if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
1984
1979
if let embedWeight = sanitizedWeights [ " language_model.model.embed_tokens.weight " ] {
1985
- print ( " 🔍 Tying lm_head weight. " )
1986
1980
sanitizedWeights [ " language_model.lm_head.weight " ] = embedWeight
1987
1981
}
1988
1982
}
1989
1983
1990
- // Step 5: Load the weights
1991
- print ( " 🔍 Attempting to load \( sanitizedWeights. count) final weights... " )
1984
+ // Load the weights
1992
1985
try model. update ( parameters: ModuleParameters . unflattened ( sanitizedWeights) , verify: [ . all] )
1993
- print ( " ✅ Model loaded successfully! " )
1994
-
1995
1986
return model
1996
1987
}
1997
1988
}
@@ -2211,7 +2202,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
2211
2202
let expectedInputSuffix = featureDims + [ numChannels]
2212
2203
assert ( Array ( x. shape. suffix ( expectedInputSuffix. count) ) == expectedInputSuffix)
2213
2204
2214
- if let mask = mask {
2205
+ if let mask {
2215
2206
assert ( mask. shape == Array ( x. shape. prefix ( 2 ) ) )
2216
2207
assert ( mask. dtype == . bool)
2217
2208
}
@@ -2221,7 +2212,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
2221
2212
let xCalc = x. asType ( calcDtype)
2222
2213
2223
2214
let maskCalc : MLXArray
2224
- if let mask = mask {
2215
+ if let mask {
2225
2216
let maskSuffixShape = Array ( repeating: 1 , count: expectedInputSuffix. count)
2226
2217
maskCalc = mask. reshaped ( Array ( mask. shape) + maskSuffixShape) . asType ( calcDtype)
2227
2218
} else {
@@ -2848,7 +2839,7 @@ private func rmsNorm2d(
2848
2839
let vMean = mean ( v, axis: 1 , keepDims: true )
2849
2840
var result = x * rsqrt( vMean + eps)
2850
2841
2851
- if let weight = weight {
2842
+ if let weight {
2852
2843
let weightReshaped = weight. reshaped ( [ 1 , - 1 , 1 , 1 ] )
2853
2844
result = result. asType ( dtype) * weightReshaped
2854
2845
}
@@ -3061,7 +3052,7 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
3061
3052
)
3062
3053
3063
3054
// Layer Scale
3064
- if let layerScaleInitValue = layerScaleInitValue {
3055
+ if let layerScaleInitValue {
3065
3056
self . _layerScale. wrappedValue = LayerScale2d (
3066
3057
dim: outChannels, initValues: layerScaleInitValue)
3067
3058
} else {
@@ -3420,7 +3411,7 @@ private class MobileAttention: Module, UnaryLayer {
3420
3411
}
3421
3412
3422
3413
// Layer scaling
3423
- if let layerScaleInitValue = layerScaleInitValue {
3414
+ if let layerScaleInitValue {
3424
3415
self . _layerScale. wrappedValue = LayerScale2d (
3425
3416
dim: outChannels, initValues: layerScaleInitValue)
3426
3417
} else {
@@ -3843,7 +3834,6 @@ private class Gemma3nVisionModel: Module {
3843
3834
sanitizedWeights [ k] = v
3844
3835
}
3845
3836
} else {
3846
- // THIS IS THE MISSING BLOCK
3847
3837
// Copy all other weights (biases, norm layers, etc.)
3848
3838
sanitizedWeights [ k] = v
3849
3839
}
@@ -3955,7 +3945,7 @@ private class Gemma3nAudioModel: Module {
3955
3945
for (k, v) in weights {
3956
3946
if k. contains ( " conv.weight " ) {
3957
3947
// The checkArrayShape function is not robust.
3958
- // The Python reference doesn't use it. It's safer to just transpose.
3948
+ // The Python implementation doesn't use it. It's safer to just transpose.
3959
3949
// Assuming NCHW -> NHWC for Conv2d
3960
3950
if v. ndim == 4 {
3961
3951
sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
@@ -3970,7 +3960,6 @@ private class Gemma3nAudioModel: Module {
3970
3960
sanitizedWeights [ k] = v
3971
3961
}
3972
3962
} else {
3973
- // THIS IS THE MISSING BLOCK
3974
3963
sanitizedWeights [ k] = v
3975
3964
}
3976
3965
}
@@ -4175,8 +4164,8 @@ public struct Gemma3nProcessorConfiguration: Codable, Sendable {
4175
4164
public let doPanAndScan : Bool ?
4176
4165
4177
4166
// Token identifiers - use default values that match Python implementation
4178
- public var imageTokenId : Int { 262145 } // From Python: image_token_id = 262145
4179
- public var audioTokenId : Int { 262273 } // From Python: audio_token_id = 262273
4167
+ public var imageTokenId : Int { 262145 }
4168
+ public var audioTokenId : Int { 262273 }
4180
4169
4181
4170
public struct ImageSize : Codable , Sendable {
4182
4171
public let height : Int
0 commit comments