Skip to content

Commit 3d61c89

Browse files
committed
Fix sanitization and computed layers
1 parent 79df57c commit 3d61c89

File tree

1 file changed

+57
-86
lines changed

1 file changed

+57
-86
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 57 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -1340,33 +1340,27 @@ private class LanguageModel: Module, KVCacheDimensionProvider {
13401340
}
13411341

13421342
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
1343-
var sanitizedWeights = [String: MLXArray]()
1344-
1343+
var sanitizedWeights = weights
13451344
for (k, v) in weights {
1346-
// Skip rotary embedding inverse frequency weights (matches Python exactly)
1347-
if k.contains("self_attn.rotary_emb.inv_freq") {
1348-
continue
1349-
}
1350-
// Python logic: if "language_model.model" not in k and "language_model.lm_head" not in k:
1351-
else if !k.contains("language_model.model") && !k.contains("language_model.lm_head") {
1345+
if !k.contains("language_model.model") && !k.contains("language_model.lm_head") {
1346+
// Transform keys that don't contain the specific patterns
13521347
let newKey = k.replacingOccurrences(
13531348
of: "language_model", with: "language_model.model")
13541349
sanitizedWeights[newKey] = v
1355-
}
1356-
// Otherwise, keep the key as is
1357-
else {
1350+
} else if k.contains("self_attn.rotary_emb.inv_freq") {
1351+
// Skip rotary embedding inverse frequency weights
1352+
continue
1353+
} else {
13581354
sanitizedWeights[k] = v
13591355
}
13601356
}
1361-
1362-
// If lm_head weight is missing, use embed_tokens weight as fallback (matches Python exactly)
1357+
// Handle tied lm_head weights
13631358
if sanitizedWeights["language_model.lm_head.weight"] == nil {
13641359
let embedTokensKey = "language_model.model.embed_tokens.weight"
13651360
if let embedWeight = sanitizedWeights[embedTokensKey] {
13661361
sanitizedWeights["language_model.lm_head.weight"] = embedWeight
13671362
}
13681363
}
1369-
13701364
return sanitizedWeights
13711365
}
13721366
}
@@ -1676,7 +1670,6 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
16761670
self._languageModel.wrappedValue = LanguageModel(config: config.textConfig)
16771671
self._visionTower.wrappedValue = Gemma3nVisionModel(config: config.visionConfig)
16781672
self._audioTower.wrappedValue = Gemma3nAudioModel(config: config.audioConfig)
1679-
16801673
self._embedVision.wrappedValue = Gemma3nMultimodalEmbedder(
16811674
multimodalConfig: config.visionConfig,
16821675
textConfig: config.textConfig
@@ -1893,20 +1886,16 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18931886
return languageModel(inputs: inputs, cache: convertedCache).logits
18941887
}
18951888

1896-
// In class Gemma3n
18971889
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
18981890
var sanitizedWeights = [String: MLXArray]()
1899-
1900-
// Remove the "model." prefix from keys.
19011891
for (k, v) in weights {
1902-
if k.hasPrefix("model.") {
1892+
if k.starts(with: "model.") {
19031893
let newKey = k.split(separator: ".").dropFirst().joined(separator: ".")
19041894
sanitizedWeights[newKey] = v
19051895
} else {
19061896
sanitizedWeights[k] = v
19071897
}
19081898
}
1909-
19101899
return sanitizedWeights
19111900
}
19121901

@@ -1937,14 +1926,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19371926
weights.merge(fileWeights) { _, new in new }
19381927
}
19391928

1940-
// Main sanitization (remove "model." prefix)
19411929
var sanitizedWeights = model.sanitize(weights: weights)
1942-
1943-
// Vision model sanitization (transpose conv weights)
1944-
sanitizedWeights = Gemma3nVisionModel.sanitizeWeights(sanitizedWeights)
1945-
1946-
// Audio model sanitization (transpose conv weights)
1947-
sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1930+
sanitizedWeights = model.visionTower.sanitize(weights: sanitizedWeights)
1931+
// The audio and language sanitization is not done in the Python implementation
1932+
// sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933+
// sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
19481934

19491935
// Handle tied lm_head weights
19501936
if sanitizedWeights["language_model.lm_head.weight"] == nil {
@@ -1992,7 +1978,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
19921978
let maxForward: Int
19931979

19941980
@ModuleInfo(key: "pos_proj") var posProj: Linear
1995-
@ModuleInfo(key: "inv_timescales") var invTimescales: MLXArray
1981+
private let _invTimescales: MLXArray
19961982

19971983
init(config: AudioConfig) {
19981984
self.config = config
@@ -2016,7 +2002,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20162002
MLXArray(0 ..< numTimescales).asType(.float32) * (-logTimescaleIncrement)
20172003
)
20182004

2019-
self._invTimescales.wrappedValue = expandedDimensions(
2005+
self._invTimescales = expandedDimensions(
20202006
expandedDimensions(invTimescales, axis: 0),
20212007
axis: 0
20222008
)
@@ -2028,7 +2014,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20282014
assert(position.ndim == 2)
20292015
let positionFloat = expandedDimensions(position.asType(.float32), axis: -1)
20302016

2031-
let scaledTime = positionFloat * invTimescales
2017+
let scaledTime = positionFloat * _invTimescales
20322018
let timingSignal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1)
20332019
return timingSignal.asType(dtype)
20342020
}
@@ -2328,6 +2314,7 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
23282314

23292315
let fInPadded = currentFForBlockInput + padFLeft + padFRight
23302316
let fOutAfterConv = (fInPadded - kernelW) / strideW + 1
2317+
23312318
calculatedFOutDims.append(fOutAfterConv)
23322319
currentFForBlockInput = fOutAfterConv
23332320
}
@@ -2389,8 +2376,8 @@ private class Gemma3nAudioAttention: Module {
23892376
let attentionLogitsSoftCap: Float
23902377
let contextSize: Int
23912378
let qScale: Float
2392-
let localCausalValidMask: MLXArray
2393-
let softcap: MLXArray
2379+
private let _localCausalValidMask: MLXArray
2380+
private let _softcap: MLXArray
23942381

23952382
@ModuleInfo(key: "relative_position_embedding") var relativePositionEmbedding:
23962383
Gemma3nAudioRelativePositionEmbedding
@@ -2434,9 +2421,10 @@ private class Gemma3nAudioAttention: Module {
24342421
)
24352422

24362423
let localCausalValidMaskTemp = MLXArray.ones([chunkSize, contextSize], dtype: .bool)
2437-
self.localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask .&& upperCausalMask
2424+
self._localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask
2425+
.&& upperCausalMask
24382426

2439-
self.softcap = MLXArray(attentionLogitsSoftCap, dtype: .float32)
2427+
self._softcap = MLXArray(attentionLogitsSoftCap, dtype: .float32)
24402428

24412429
super.init()
24422430
}
@@ -2536,7 +2524,7 @@ private class Gemma3nAudioAttention: Module {
25362524

25372525
let conditionFromCausality = expandedDimensions(
25382526
expandedDimensions(
2539-
expandedDimensions(localCausalValidMask, axis: 0),
2527+
expandedDimensions(_localCausalValidMask, axis: 0),
25402528
axis: 0
25412529
),
25422530
axis: 0
@@ -2547,9 +2535,9 @@ private class Gemma3nAudioAttention: Module {
25472535
var logits = relativePositionEmbedding(queryBlocks, keyBlocks)
25482536

25492537
// Apply attention logit softcap
2550-
logits = logits / softcap
2538+
logits = logits / _softcap
25512539
logits = tanh(logits)
2552-
logits = logits * softcap
2540+
logits = logits * _softcap
25532541

25542542
// Apply the combined mask
25552543
logits = MLX.where(
@@ -2635,8 +2623,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
26352623
private let _postLayerScale: MLXArray
26362624

26372625
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNormWithScale
2638-
@ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Linear
2639-
@ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Linear
2626+
private let _ffwLayer1: Linear
2627+
private let _ffwLayer2: Linear
26402628
@ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma3nRMSNormWithScale
26412629

26422630
init(config: AudioConfig) {
@@ -2645,8 +2633,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
26452633
self._postLayerScale = MLXArray(config.confResidualWeight)
26462634

26472635
self._preLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2648-
self._ffwLayer1.wrappedValue = Linear(config.hiddenSize, config.hiddenSize * 4, bias: false)
2649-
self._ffwLayer2.wrappedValue = Linear(config.hiddenSize * 4, config.hiddenSize, bias: false)
2636+
self._ffwLayer1 = Linear(config.hiddenSize, config.hiddenSize * 4, bias: false)
2637+
self._ffwLayer2 = Linear(config.hiddenSize * 4, config.hiddenSize, bias: false)
26502638
self._postLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
26512639

26522640
super.init()
@@ -2656,9 +2644,9 @@ private class Gemma3nAudioConformerFeedForward: Module {
26562644
let residual = x
26572645
let clippedX = clip(x, min: -_gradientClipping, max: _gradientClipping)
26582646
var result = preLayerNorm(clippedX)
2659-
result = ffwLayer1(result)
2647+
result = _ffwLayer1(result)
26602648
result = silu(result)
2661-
result = ffwLayer2(result)
2649+
result = _ffwLayer2(result)
26622650
let clippedResult = clip(result, min: -_gradientClipping, max: _gradientClipping)
26632651
let normedResult = postLayerNorm(clippedResult)
26642652
return residual + (normedResult * _postLayerScale)
@@ -2737,17 +2725,17 @@ private class Gemma3nAudioConformerLightConv1d: Module {
27372725
// MARK: - Conformer Block
27382726
private class Gemma3nAudioConformerBlock: Module {
27392727
let config: AudioConfig
2740-
private let gradientClipping: MLXArray
2728+
private let _gradientClipping: MLXArray
27412729

2742-
@ModuleInfo var ffwLayerStart: Gemma3nAudioConformerFeedForward
2730+
@ModuleInfo(key: "ffw_layer_start") var ffwLayerStart: Gemma3nAudioConformerFeedForward
27432731
@ModuleInfo var attention: Gemma3nAudioConformerAttention
27442732
@ModuleInfo var lconv1d: Gemma3nAudioConformerLightConv1d
2745-
@ModuleInfo var ffwLayerEnd: Gemma3nAudioConformerFeedForward
2733+
@ModuleInfo(key: "ffw_layer_end") var ffwLayerEnd: Gemma3nAudioConformerFeedForward
27462734
@ModuleInfo var norm: Gemma3nRMSNormWithScale
27472735

27482736
init(config: AudioConfig) {
27492737
self.config = config
2750-
self.gradientClipping = MLXArray(config.gradientClipping)
2738+
self._gradientClipping = MLXArray(config.gradientClipping)
27512739

27522740
self._ffwLayerStart.wrappedValue = Gemma3nAudioConformerFeedForward(config: config)
27532741
self._attention.wrappedValue = Gemma3nAudioConformerAttention(config: config)
@@ -2771,7 +2759,7 @@ private class Gemma3nAudioConformerBlock: Module {
27712759

27722760
result = lconv1d(audioencodingsForLconvInput)
27732761
result = ffwLayerEnd(result)
2774-
result = clip(result, min: -gradientClipping, max: gradientClipping)
2762+
result = clip(result, min: -_gradientClipping, max: _gradientClipping)
27752763
return norm(result)
27762764
}
27772765
}
@@ -2856,7 +2844,8 @@ private func numGroups(groupSize: Int?, channels: Int) -> Int {
28562844
}
28572845
// NOTE: groupSize == 1 -> depthwise conv
28582846
assert(channels % groupSize == 0)
2859-
return channels / groupSize
2847+
let groups = channels / groupSize
2848+
return groups
28602849
}
28612850

28622851
private func makeDivisible(
@@ -3082,6 +3071,7 @@ private class EdgeResidual: Module, UnaryLayer {
30823071
self.hasSkip = (inChannels == outChannels && stride == 1) && !noskip
30833072

30843073
let padding = (expKernelSize - 1) / 2
3074+
30853075
self._convExp.wrappedValue = Conv2d(
30863076
inputChannels: inChannels,
30873077
outputChannels: midChannels,
@@ -3195,6 +3185,7 @@ private class MultiQueryAttention2d: Module {
31953185
groups: dim, // Depthwise
31963186
bias: false
31973187
)
3188+
31983189
self._keyNorm.wrappedValue = RMSNormAct2d(numChannels: dim, eps: 1e-6, applyAct: false)
31993190
} else {
32003191
self._keyDownConv.wrappedValue = Identity()
@@ -3780,37 +3771,23 @@ private class Gemma3nVisionModel: Module {
37803771
}
37813772

37823773
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
3783-
return Self.sanitizeWeights(weights)
3784-
}
3785-
3786-
static func sanitizeWeights(_ weights: [String: MLXArray]) -> [String: MLXArray] {
3787-
var sanitizedWeights = [String: MLXArray]()
3774+
var sanitizedWeights = weights
37883775
var skipTranspose = false
3789-
3790-
// This logic is correct
37913776
let testKey = "vision_tower.timm_model.blocks.0.0.conv_exp.weight"
3792-
if let convWeight = weights[testKey] {
3793-
let shape = convWeight.shape
3794-
if shape.count == 4, shape[3] > shape[1] {
3795-
skipTranspose = true
3796-
}
3777+
if let convWeight = weights[testKey], convWeight.ndim == 4,
3778+
convWeight.shape[3] > convWeight.shape[1]
3779+
{
3780+
skipTranspose = true
37973781
}
3798-
37993782
for (k, v) in weights {
38003783
if (k.contains("conv") && k.contains("weight"))
38013784
|| (k.contains("attn") && k.contains("proj.weight"))
38023785
{
3803-
if v.shape.count == 4 && !skipTranspose {
3786+
if v.ndim == 4 && !skipTranspose {
38043787
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
3805-
} else {
3806-
sanitizedWeights[k] = v
38073788
}
3808-
} else {
3809-
// Copy all other weights (biases, norm layers, etc.)
3810-
sanitizedWeights[k] = v
38113789
}
38123790
}
3813-
38143791
return sanitizedWeights
38153792
}
38163793
}
@@ -3828,8 +3805,9 @@ private class Gemma3nAudioModel: Module {
38283805

38293806
self._subsampleConvProjection.wrappedValue = Gemma3nAudioSubSampleConvProjection(
38303807
config: config)
3831-
self._conformer.wrappedValue = (0 ..< config.confNumHiddenLayers).map { _ in
3832-
Gemma3nAudioConformerBlock(config: config)
3808+
3809+
self._conformer.wrappedValue = (0 ..< config.confNumHiddenLayers).map { i in
3810+
return Gemma3nAudioConformerBlock(config: config)
38333811
}
38343812

38353813
super.init()
@@ -3914,32 +3892,25 @@ private class Gemma3nAudioModel: Module {
39143892
/// Sanitizes weights by transposing convolution layers if they are not
39153893
/// already in the expected MLX format.
39163894
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
3917-
var sanitizedWeights = [String: MLXArray]()
3918-
3895+
var sanitizedWeights = weights
3896+
// Iterate over the original keys to decide which ones to modify in the copy.
39193897
for (k, v) in weights {
39203898
if k.contains("conv.weight") {
3921-
// A Conv2D weight should be 4D.
3922-
// If it is, check if it needs transposing from NCHW to NHWC.
3923-
// If checkArrayShape is true, it's already in the correct format.
3924-
if v.ndim == 4 && !checkArrayShape(v) {
3925-
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
3926-
} else {
3899+
if checkArrayShape(v) {
39273900
sanitizedWeights[k] = v
3901+
} else {
3902+
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
39283903
}
39293904
} else if k.contains("conv1d.weight") {
3930-
// A Conv1D weight should be 3D.
3931-
// If it is, check if it needs transposing from NCL to NLC.
3932-
if v.ndim == 3 && !checkArrayShape(v) {
3933-
sanitizedWeights[k] = v.transposed(0, 2, 1)
3934-
} else {
3905+
if true {
39353906
sanitizedWeights[k] = v
3907+
} else {
3908+
sanitizedWeights[k] = v.transposed(0, 2, 1)
39363909
}
39373910
} else {
3938-
// For all other weights, keep them as they are.
39393911
sanitizedWeights[k] = v
39403912
}
39413913
}
3942-
39433914
return sanitizedWeights
39443915
}
39453916
}

0 commit comments

Comments
 (0)