Skip to content

Commit 8e1d7d5

Browse files
committed
Clean up
1 parent 0d6d026 commit 8e1d7d5

File tree

1 file changed

+24
-35
lines changed

1 file changed

+24
-35
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 24 additions & 35 deletions
Original file line number · Diff line number · Diff line change
@@ -1146,18 +1146,18 @@ private class Gemma3Model: Module {
11461146
perLayerInputs: MLXArray? = nil
11471147
) -> MLXArray {
11481148
var h: MLXArray
1149-
if let inputsEmbeds = inputsEmbeds {
1149+
if let inputsEmbeds {
11501150
h = inputsEmbeds
1151-
} else if let inputs = inputs {
1151+
} else if let inputs {
11521152
h = embedTokens(inputs)
11531153
} else {
11541154
fatalError("Either inputs or inputsEmbeds must be provided")
11551155
}
11561156

11571157
let perLayerInputsProcessed: MLXArray
1158-
if let perLayerInputs = perLayerInputs {
1158+
if let perLayerInputs {
11591159
perLayerInputsProcessed = perLayerInputs
1160-
} else if let inputs = inputs {
1160+
} else if let inputs {
11611161
perLayerInputsProcessed = getPerLayerInputs(inputs)
11621162
} else {
11631163
fatalError("Cannot generate per layer inputs without input ids")
@@ -1213,7 +1213,7 @@ private class Gemma3Model: Module {
12131213
== "global_attention"
12141214

12151215
let localMask: MLXFast.ScaledDotProductAttentionMaskMode
1216-
if let mask = mask {
1216+
if let mask {
12171217
localMask = mask
12181218
} else if isGlobal {
12191219
localMask = fullMask
@@ -1437,9 +1437,9 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
14371437
}
14381438

14391439
let embNorm: MLXArray
1440-
if let inputsEmbeds = inputsEmbeds {
1440+
if let inputsEmbeds {
14411441
embNorm = softEmbeddingNorm(inputsEmbeds)
1442-
} else if let inputIds = inputIds {
1442+
} else if let inputIds {
14431443
let hardEmb = embedding(inputIds - vocabOffset)
14441444
embNorm = hardEmbeddingNorm(hardEmb)
14451445
} else {
@@ -1490,7 +1490,7 @@ private func gemma3nAttentionWithCacheUpdate(
14901490
// Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
14911491
let (cachedKeys, cachedValues): (MLXArray, MLXArray)
14921492

1493-
if let cache = cache {
1493+
if let cache {
14941494
(cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
14951495
} else {
14961496
(cachedKeys, cachedValues) = (keys, values)
@@ -1667,7 +1667,6 @@ private func maskedScatter(
16671667
private func checkArrayShape(_ arr: MLXArray) -> Bool {
16681668
let shape = arr.shape
16691669
guard shape.count == 4 else {
1670-
print("🔍 checkArrayShape: Array has \(shape.count) dimensions, not 4")
16711670
return false
16721671
}
16731672

@@ -1792,7 +1791,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
17921791
}
17931792

17941793
// Process audio features
1795-
if let inputFeatures = inputFeatures, let inputFeaturesMask = inputFeaturesMask {
1794+
if let inputFeatures, let inputFeaturesMask = inputFeaturesMask {
17961795
let (audioFeatures, audioMask) = getAudioFeatures(inputFeatures, .!inputFeaturesMask)
17971796
let audioPaddingIds = MLXArray([config.vocabSize - 1]).expandedDimensions(axis: 0)
17981797
let audioPaddingEmbs = embedAudio.callAsFunction(audioPaddingIds, inputsEmbeds: nil)
@@ -1862,7 +1861,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18621861
) -> MLXArray {
18631862
let specialModalityMask: MLXArray
18641863

1865-
if let inputIds = inputIds {
1864+
if let inputIds {
18661865
specialModalityMask = expandedDimensions(inputIds .== tokenId, axis: -1)
18671866
} else {
18681867
// When inputIds is nil, create mask by comparing embeddings
@@ -1924,10 +1923,9 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19241923

19251924
// In class Gemma3n
19261925
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
1927-
print("🔍 Gemma3n.sanitize: Starting with \(weights.count) weights")
19281926
var sanitizedWeights = [String: MLXArray]()
19291927

1930-
// This function's ONLY job is to remove the "model." prefix from keys.
1928+
// Remove the "model." prefix from keys.
19311929
for (k, v) in weights {
19321930
if k.hasPrefix("model.") {
19331931
let newKey = k.split(separator: ".").dropFirst().joined(separator: ".")
@@ -1937,13 +1935,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19371935
}
19381936
}
19391937

1940-
print("🔍 Gemma3n.sanitize: After prefix removal, have \(sanitizedWeights.count) weights")
19411938
return sanitizedWeights
19421939
}
19431940

19441941
public static func fromPretrained(pathOrHfRepo: String) throws -> Gemma3n {
19451942
let path = URL(fileURLWithPath: pathOrHfRepo)
1946-
print("🔍 Gemma3n.fromPretrained: Loading from \(pathOrHfRepo)")
19471943

19481944
let configPath = path.appendingPathComponent("config.json")
19491945
let configData = try Data(contentsOf: configPath)
@@ -1968,30 +1964,25 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19681964
let fileWeights = try loadArrays(url: path.appendingPathComponent(weightFile))
19691965
weights.merge(fileWeights) { _, new in new }
19701966
}
1971-
print("🔍 Gemma3n.fromPretrained: Total weights loaded: \(weights.count)")
19721967

1973-
// Step 1: Main sanitization (remove "model." prefix)
1968+
// Main sanitization (remove "model." prefix)
19741969
var sanitizedWeights = model.sanitize(weights: weights)
19751970

1976-
// Step 2: Vision model sanitization (transpose conv weights)
1971+
// Vision model sanitization (transpose conv weights)
19771972
sanitizedWeights = Gemma3nVisionModel.sanitizeWeights(sanitizedWeights)
19781973

1979-
// Step 3: Audio model sanitization (transpose conv weights) - THIS WAS MISSING
1974+
// Audio model sanitization (transpose conv weights)
19801975
sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
19811976

1982-
// Step 4: Handle tied lm_head weights
1977+
// Handle tied lm_head weights
19831978
if sanitizedWeights["language_model.lm_head.weight"] == nil {
19841979
if let embedWeight = sanitizedWeights["language_model.model.embed_tokens.weight"] {
1985-
print("🔍 Tying lm_head weight.")
19861980
sanitizedWeights["language_model.lm_head.weight"] = embedWeight
19871981
}
19881982
}
19891983

1990-
// Step 5: Load the weights
1991-
print("🔍 Attempting to load \(sanitizedWeights.count) final weights...")
1984+
// Load the weights
19921985
try model.update(parameters: ModuleParameters.unflattened(sanitizedWeights), verify: [.all])
1993-
print("✅ Model loaded successfully!")
1994-
19951986
return model
19961987
}
19971988
}
@@ -2211,7 +2202,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22112202
let expectedInputSuffix = featureDims + [numChannels]
22122203
assert(Array(x.shape.suffix(expectedInputSuffix.count)) == expectedInputSuffix)
22132204

2214-
if let mask = mask {
2205+
if let mask {
22152206
assert(mask.shape == Array(x.shape.prefix(2)))
22162207
assert(mask.dtype == .bool)
22172208
}
@@ -2221,7 +2212,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22212212
let xCalc = x.asType(calcDtype)
22222213

22232214
let maskCalc: MLXArray
2224-
if let mask = mask {
2215+
if let mask {
22252216
let maskSuffixShape = Array(repeating: 1, count: expectedInputSuffix.count)
22262217
maskCalc = mask.reshaped(Array(mask.shape) + maskSuffixShape).asType(calcDtype)
22272218
} else {
@@ -2848,7 +2839,7 @@ private func rmsNorm2d(
28482839
let vMean = mean(v, axis: 1, keepDims: true)
28492840
var result = x * rsqrt(vMean + eps)
28502841

2851-
if let weight = weight {
2842+
if let weight {
28522843
let weightReshaped = weight.reshaped([1, -1, 1, 1])
28532844
result = result.asType(dtype) * weightReshaped
28542845
}
@@ -3061,7 +3052,7 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
30613052
)
30623053

30633054
// Layer Scale
3064-
if let layerScaleInitValue = layerScaleInitValue {
3055+
if let layerScaleInitValue {
30653056
self._layerScale.wrappedValue = LayerScale2d(
30663057
dim: outChannels, initValues: layerScaleInitValue)
30673058
} else {
@@ -3420,7 +3411,7 @@ private class MobileAttention: Module, UnaryLayer {
34203411
}
34213412

34223413
// Layer scaling
3423-
if let layerScaleInitValue = layerScaleInitValue {
3414+
if let layerScaleInitValue {
34243415
self._layerScale.wrappedValue = LayerScale2d(
34253416
dim: outChannels, initValues: layerScaleInitValue)
34263417
} else {
@@ -3843,7 +3834,6 @@ private class Gemma3nVisionModel: Module {
38433834
sanitizedWeights[k] = v
38443835
}
38453836
} else {
3846-
// THIS IS THE MISSING BLOCK
38473837
// Copy all other weights (biases, norm layers, etc.)
38483838
sanitizedWeights[k] = v
38493839
}
@@ -3955,7 +3945,7 @@ private class Gemma3nAudioModel: Module {
39553945
for (k, v) in weights {
39563946
if k.contains("conv.weight") {
39573947
// The checkArrayShape function is not robust.
3958-
// The Python reference doesn't use it. It's safer to just transpose.
3948+
// The Python implementation doesn't use it. It's safer to just transpose.
39593949
// Assuming NCHW -> NHWC for Conv2d
39603950
if v.ndim == 4 {
39613951
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
@@ -3970,7 +3960,6 @@ private class Gemma3nAudioModel: Module {
39703960
sanitizedWeights[k] = v
39713961
}
39723962
} else {
3973-
// THIS IS THE MISSING BLOCK
39743963
sanitizedWeights[k] = v
39753964
}
39763965
}
@@ -4175,8 +4164,8 @@ public struct Gemma3nProcessorConfiguration: Codable, Sendable {
41754164
public let doPanAndScan: Bool?
41764165

41774166
// Token identifiers - use default values that match Python implementation
4178-
public var imageTokenId: Int { 262145 } // From Python: image_token_id = 262145
4179-
public var audioTokenId: Int { 262273 } // From Python: audio_token_id = 262273
4167+
public var imageTokenId: Int { 262145 }
4168+
public var audioTokenId: Int { 262273 }
41804169

41814170
public struct ImageSize: Codable, Sendable {
41824171
public let height: Int

0 commit comments

Comments (0)