Skip to content

Commit 8820938

Browse files
committed
Fix sanitization, computed layers, module keys
1 parent 79df57c commit 8820938

File tree

1 file changed

+64
-93
lines changed

1 file changed

+64
-93
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 64 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,33 +1340,27 @@ private class LanguageModel: Module, KVCacheDimensionProvider {
13401340
}
13411341

13421342
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
1343-
var sanitizedWeights = [String: MLXArray]()
1344-
1343+
var sanitizedWeights = weights
13451344
for (k, v) in weights {
1346-
// Skip rotary embedding inverse frequency weights (matches Python exactly)
1347-
if k.contains("self_attn.rotary_emb.inv_freq") {
1348-
continue
1349-
}
1350-
// Python logic: if "language_model.model" not in k and "language_model.lm_head" not in k:
1351-
else if !k.contains("language_model.model") && !k.contains("language_model.lm_head") {
1345+
if !k.contains("language_model.model") && !k.contains("language_model.lm_head") {
1346+
// Transform keys that don't contain the specific patterns
13521347
let newKey = k.replacingOccurrences(
13531348
of: "language_model", with: "language_model.model")
13541349
sanitizedWeights[newKey] = v
1355-
}
1356-
// Otherwise, keep the key as is
1357-
else {
1350+
} else if k.contains("self_attn.rotary_emb.inv_freq") {
1351+
// Skip rotary embedding inverse frequency weights
1352+
continue
1353+
} else {
13581354
sanitizedWeights[k] = v
13591355
}
13601356
}
1361-
1362-
// If lm_head weight is missing, use embed_tokens weight as fallback (matches Python exactly)
1357+
// Handle tied lm_head weights
13631358
if sanitizedWeights["language_model.lm_head.weight"] == nil {
13641359
let embedTokensKey = "language_model.model.embed_tokens.weight"
13651360
if let embedWeight = sanitizedWeights[embedTokensKey] {
13661361
sanitizedWeights["language_model.lm_head.weight"] = embedWeight
13671362
}
13681363
}
1369-
13701364
return sanitizedWeights
13711365
}
13721366
}
@@ -1676,7 +1670,6 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
16761670
self._languageModel.wrappedValue = LanguageModel(config: config.textConfig)
16771671
self._visionTower.wrappedValue = Gemma3nVisionModel(config: config.visionConfig)
16781672
self._audioTower.wrappedValue = Gemma3nAudioModel(config: config.audioConfig)
1679-
16801673
self._embedVision.wrappedValue = Gemma3nMultimodalEmbedder(
16811674
multimodalConfig: config.visionConfig,
16821675
textConfig: config.textConfig
@@ -1893,20 +1886,16 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18931886
return languageModel(inputs: inputs, cache: convertedCache).logits
18941887
}
18951888

1896-
// In class Gemma3n
18971889
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
18981890
var sanitizedWeights = [String: MLXArray]()
1899-
1900-
// Remove the "model." prefix from keys.
19011891
for (k, v) in weights {
1902-
if k.hasPrefix("model.") {
1892+
if k.starts(with: "model.") {
19031893
let newKey = k.split(separator: ".").dropFirst().joined(separator: ".")
19041894
sanitizedWeights[newKey] = v
19051895
} else {
19061896
sanitizedWeights[k] = v
19071897
}
19081898
}
1909-
19101899
return sanitizedWeights
19111900
}
19121901

@@ -1937,14 +1926,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19371926
weights.merge(fileWeights) { _, new in new }
19381927
}
19391928

1940-
// Main sanitization (remove "model." prefix)
19411929
var sanitizedWeights = model.sanitize(weights: weights)
1942-
1943-
// Vision model sanitization (transpose conv weights)
1944-
sanitizedWeights = Gemma3nVisionModel.sanitizeWeights(sanitizedWeights)
1945-
1946-
// Audio model sanitization (transpose conv weights)
1947-
sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1930+
sanitizedWeights = model.visionTower.sanitize(weights: sanitizedWeights)
1931+
// The audio and language sanitization is not done in the Python implementation
1932+
// sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933+
// sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
19481934

19491935
// Handle tied lm_head weights
19501936
if sanitizedWeights["language_model.lm_head.weight"] == nil {
@@ -1992,7 +1978,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
19921978
let maxForward: Int
19931979

19941980
@ModuleInfo(key: "pos_proj") var posProj: Linear
1995-
@ModuleInfo(key: "inv_timescales") var invTimescales: MLXArray
1981+
private let _invTimescales: MLXArray
19961982

19971983
init(config: AudioConfig) {
19981984
self.config = config
@@ -2016,7 +2002,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20162002
MLXArray(0 ..< numTimescales).asType(.float32) * (-logTimescaleIncrement)
20172003
)
20182004

2019-
self._invTimescales.wrappedValue = expandedDimensions(
2005+
self._invTimescales = expandedDimensions(
20202006
expandedDimensions(invTimescales, axis: 0),
20212007
axis: 0
20222008
)
@@ -2028,7 +2014,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20282014
assert(position.ndim == 2)
20292015
let positionFloat = expandedDimensions(position.asType(.float32), axis: -1)
20302016

2031-
let scaledTime = positionFloat * invTimescales
2017+
let scaledTime = positionFloat * _invTimescales
20322018
let timingSignal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1)
20332019
return timingSignal.asType(dtype)
20342020
}
@@ -2328,6 +2314,7 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
23282314

23292315
let fInPadded = currentFForBlockInput + padFLeft + padFRight
23302316
let fOutAfterConv = (fInPadded - kernelW) / strideW + 1
2317+
23312318
calculatedFOutDims.append(fOutAfterConv)
23322319
currentFForBlockInput = fOutAfterConv
23332320
}
@@ -2389,8 +2376,8 @@ private class Gemma3nAudioAttention: Module {
23892376
let attentionLogitsSoftCap: Float
23902377
let contextSize: Int
23912378
let qScale: Float
2392-
let localCausalValidMask: MLXArray
2393-
let softcap: MLXArray
2379+
private let _localCausalValidMask: MLXArray
2380+
private let _softcap: MLXArray
23942381

23952382
@ModuleInfo(key: "relative_position_embedding") var relativePositionEmbedding:
23962383
Gemma3nAudioRelativePositionEmbedding
@@ -2434,9 +2421,10 @@ private class Gemma3nAudioAttention: Module {
24342421
)
24352422

24362423
let localCausalValidMaskTemp = MLXArray.ones([chunkSize, contextSize], dtype: .bool)
2437-
self.localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask .&& upperCausalMask
2424+
self._localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask
2425+
.&& upperCausalMask
24382426

2439-
self.softcap = MLXArray(attentionLogitsSoftCap, dtype: .float32)
2427+
self._softcap = MLXArray(attentionLogitsSoftCap, dtype: .float32)
24402428

24412429
super.init()
24422430
}
@@ -2536,7 +2524,7 @@ private class Gemma3nAudioAttention: Module {
25362524

25372525
let conditionFromCausality = expandedDimensions(
25382526
expandedDimensions(
2539-
expandedDimensions(localCausalValidMask, axis: 0),
2527+
expandedDimensions(_localCausalValidMask, axis: 0),
25402528
axis: 0
25412529
),
25422530
axis: 0
@@ -2547,9 +2535,9 @@ private class Gemma3nAudioAttention: Module {
25472535
var logits = relativePositionEmbedding(queryBlocks, keyBlocks)
25482536

25492537
// Apply attention logit softcap
2550-
logits = logits / softcap
2538+
logits = logits / _softcap
25512539
logits = tanh(logits)
2552-
logits = logits * softcap
2540+
logits = logits * _softcap
25532541

25542542
// Apply the combined mask
25552543
logits = MLX.where(
@@ -2591,10 +2579,10 @@ private class Gemma3nAudioConformerAttention: Module {
25912579
let postInFeatures: Int
25922580
private let _gradientClipping: MLXArray
25932581

2594-
@ModuleInfo var preAttnNorm: Gemma3nRMSNormWithScale
2582+
@ModuleInfo(key: "pre_attn_norm") var preAttnNorm: Gemma3nRMSNormWithScale
25952583
@ModuleInfo var attn: Gemma3nAudioAttention
25962584
@ModuleInfo var post: Linear
2597-
@ModuleInfo var postNorm: Gemma3nRMSNormWithScale
2585+
@ModuleInfo(key: "post_norm") var postNorm: Gemma3nRMSNormWithScale
25982586

25992587
init(config: AudioConfig) {
26002588
self.config = config
@@ -2737,17 +2725,17 @@ private class Gemma3nAudioConformerLightConv1d: Module {
27372725
// MARK: - Conformer Block
27382726
private class Gemma3nAudioConformerBlock: Module {
27392727
let config: AudioConfig
2740-
private let gradientClipping: MLXArray
2728+
private let _gradientClipping: MLXArray
27412729

2742-
@ModuleInfo var ffwLayerStart: Gemma3nAudioConformerFeedForward
2730+
@ModuleInfo(key: "ffw_layer_start") var ffwLayerStart: Gemma3nAudioConformerFeedForward
27432731
@ModuleInfo var attention: Gemma3nAudioConformerAttention
27442732
@ModuleInfo var lconv1d: Gemma3nAudioConformerLightConv1d
2745-
@ModuleInfo var ffwLayerEnd: Gemma3nAudioConformerFeedForward
2733+
@ModuleInfo(key: "ffw_layer_end") var ffwLayerEnd: Gemma3nAudioConformerFeedForward
27462734
@ModuleInfo var norm: Gemma3nRMSNormWithScale
27472735

27482736
init(config: AudioConfig) {
27492737
self.config = config
2750-
self.gradientClipping = MLXArray(config.gradientClipping)
2738+
self._gradientClipping = MLXArray(config.gradientClipping)
27512739

27522740
self._ffwLayerStart.wrappedValue = Gemma3nAudioConformerFeedForward(config: config)
27532741
self._attention.wrappedValue = Gemma3nAudioConformerAttention(config: config)
@@ -2771,7 +2759,7 @@ private class Gemma3nAudioConformerBlock: Module {
27712759

27722760
result = lconv1d(audioencodingsForLconvInput)
27732761
result = ffwLayerEnd(result)
2774-
result = clip(result, min: -gradientClipping, max: gradientClipping)
2762+
result = clip(result, min: -_gradientClipping, max: _gradientClipping)
27752763
return norm(result)
27762764
}
27772765
}
@@ -2856,7 +2844,8 @@ private func numGroups(groupSize: Int?, channels: Int) -> Int {
28562844
}
28572845
// NOTE: groupSize == 1 -> depthwise conv
28582846
assert(channels % groupSize == 0)
2859-
return channels / groupSize
2847+
let groups = channels / groupSize
2848+
return groups
28602849
}
28612850

28622851
private func makeDivisible(
@@ -3082,6 +3071,7 @@ private class EdgeResidual: Module, UnaryLayer {
30823071
self.hasSkip = (inChannels == outChannels && stride == 1) && !noskip
30833072

30843073
let padding = (expKernelSize - 1) / 2
3074+
30853075
self._convExp.wrappedValue = Conv2d(
30863076
inputChannels: inChannels,
30873077
outputChannels: midChannels,
@@ -3139,17 +3129,17 @@ private class MultiQueryAttention2d: Module {
31393129
let valueDim: Int
31403130
let scale: Float
31413131

3142-
@ModuleInfo var queryProj: Conv2d
3132+
@ModuleInfo(key: "query_proj") var queryProj: Conv2d
31433133

3144-
@ModuleInfo var keyDownConv: UnaryLayer
3145-
@ModuleInfo var keyNorm: UnaryLayer
3146-
@ModuleInfo var valueDownConv: UnaryLayer
3147-
@ModuleInfo var valueNorm: UnaryLayer
3134+
@ModuleInfo(key: "key_down_conv") var keyDownConv: UnaryLayer
3135+
@ModuleInfo(key: "key_norm") var keyNorm: UnaryLayer
3136+
@ModuleInfo(key: "value_down_conv") var valueDownConv: UnaryLayer
3137+
@ModuleInfo(key: "value_norm") var valueNorm: UnaryLayer
31483138

3149-
@ModuleInfo var keyProj: Conv2d
3150-
@ModuleInfo var valueProj: Conv2d
3139+
@ModuleInfo(key: "key_proj") var keyProj: Conv2d
3140+
@ModuleInfo(key: "value_proj") var valueProj: Conv2d
31513141
@ModuleInfo(key: "attn_drop") var attnDrop: UnaryLayer
3152-
@ModuleInfo var outputProj: Conv2d
3142+
@ModuleInfo(key: "output_proj") var outputProj: Conv2d
31533143
@ModuleInfo(key: "proj_drop") var projDrop: UnaryLayer
31543144

31553145
init(
@@ -3195,6 +3185,7 @@ private class MultiQueryAttention2d: Module {
31953185
groups: dim, // Depthwise
31963186
bias: false
31973187
)
3188+
31983189
self._keyNorm.wrappedValue = RMSNormAct2d(numChannels: dim, eps: 1e-6, applyAct: false)
31993190
} else {
32003191
self._keyDownConv.wrappedValue = Identity()
@@ -3323,8 +3314,8 @@ private class MobileAttention: Module, UnaryLayer {
33233314

33243315
@ModuleInfo var norm: RMSNormAct2d
33253316
@ModuleInfo var attn: MultiQueryAttention2d
3326-
@ModuleInfo var layerScale: UnaryLayer
3327-
@ModuleInfo var dropPath: Identity
3317+
@ModuleInfo(key: "layer_scale") var layerScale: UnaryLayer
3318+
@ModuleInfo(key: "drop_path") var dropPath: Identity
33283319

33293320
init(
33303321
inChannels: Int,
@@ -3544,7 +3535,7 @@ private class MobileNetV5MultiScaleFusionAdapter: Module {
35443535

35453536
@ModuleInfo var ffn: UniversalInvertedResidual
35463537
@ModuleInfo var norm: RMSNormAct2d
3547-
@ModuleInfo var avgPool: AvgPool2d
3538+
@ModuleInfo(key: "avg_pool") var avgPool: AvgPool2d
35483539

35493540
init(
35503541
inChannels: [Int],
@@ -3780,37 +3771,23 @@ private class Gemma3nVisionModel: Module {
37803771
}
37813772

37823773
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
3783-
return Self.sanitizeWeights(weights)
3784-
}
3785-
3786-
static func sanitizeWeights(_ weights: [String: MLXArray]) -> [String: MLXArray] {
3787-
var sanitizedWeights = [String: MLXArray]()
3774+
var sanitizedWeights = weights
37883775
var skipTranspose = false
3789-
3790-
// This logic is correct
37913776
let testKey = "vision_tower.timm_model.blocks.0.0.conv_exp.weight"
3792-
if let convWeight = weights[testKey] {
3793-
let shape = convWeight.shape
3794-
if shape.count == 4, shape[3] > shape[1] {
3795-
skipTranspose = true
3796-
}
3777+
if let convWeight = weights[testKey], convWeight.ndim == 4,
3778+
convWeight.shape[3] > convWeight.shape[1]
3779+
{
3780+
skipTranspose = true
37973781
}
3798-
37993782
for (k, v) in weights {
38003783
if (k.contains("conv") && k.contains("weight"))
38013784
|| (k.contains("attn") && k.contains("proj.weight"))
38023785
{
3803-
if v.shape.count == 4 && !skipTranspose {
3786+
if v.ndim == 4 && !skipTranspose {
38043787
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
3805-
} else {
3806-
sanitizedWeights[k] = v
38073788
}
3808-
} else {
3809-
// Copy all other weights (biases, norm layers, etc.)
3810-
sanitizedWeights[k] = v
38113789
}
38123790
}
3813-
38143791
return sanitizedWeights
38153792
}
38163793
}
@@ -3828,8 +3805,9 @@ private class Gemma3nAudioModel: Module {
38283805

38293806
self._subsampleConvProjection.wrappedValue = Gemma3nAudioSubSampleConvProjection(
38303807
config: config)
3831-
self._conformer.wrappedValue = (0 ..< config.confNumHiddenLayers).map { _ in
3832-
Gemma3nAudioConformerBlock(config: config)
3808+
3809+
self._conformer.wrappedValue = (0 ..< config.confNumHiddenLayers).map { i in
3810+
return Gemma3nAudioConformerBlock(config: config)
38333811
}
38343812

38353813
super.init()
@@ -3914,32 +3892,25 @@ private class Gemma3nAudioModel: Module {
39143892
/// Sanitizes weights by transposing convolution layers if they are not
39153893
/// already in the expected MLX format.
39163894
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
3917-
var sanitizedWeights = [String: MLXArray]()
3918-
3895+
var sanitizedWeights = weights
3896+
// Iterate over the original keys to decide which ones to modify in the copy.
39193897
for (k, v) in weights {
39203898
if k.contains("conv.weight") {
3921-
// A Conv2D weight should be 4D.
3922-
// If it is, check if it needs transposing from NCHW to NHWC.
3923-
// If checkArrayShape is true, it's already in the correct format.
3924-
if v.ndim == 4 && !checkArrayShape(v) {
3925-
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
3926-
} else {
3899+
if checkArrayShape(v) {
39273900
sanitizedWeights[k] = v
3901+
} else {
3902+
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
39283903
}
39293904
} else if k.contains("conv1d.weight") {
3930-
// A Conv1D weight should be 3D.
3931-
// If it is, check if it needs transposing from NCL to NLC.
3932-
if v.ndim == 3 && !checkArrayShape(v) {
3933-
sanitizedWeights[k] = v.transposed(0, 2, 1)
3934-
} else {
3905+
if true {
39353906
sanitizedWeights[k] = v
3907+
} else {
3908+
sanitizedWeights[k] = v.transposed(0, 2, 1)
39363909
}
39373910
} else {
3938-
// For all other weights, keep them as they are.
39393911
sanitizedWeights[k] = v
39403912
}
39413913
}
3942-
39433914
return sanitizedWeights
39443915
}
39453916
}

0 commit comments

Comments
 (0)