@@ -1340,33 +1340,27 @@ private class LanguageModel: Module, KVCacheDimensionProvider {
1340
1340
}
1341
1341
1342
1342
/// Normalizes checkpoint keys for the language model.
///
/// - Keys containing `self_attn.rotary_emb.inv_freq` are dropped.
/// - Keys containing neither `language_model.model` nor `language_model.lm_head`
///   have every occurrence of `language_model` rewritten to `language_model.model`.
/// - All remaining keys are copied unchanged.
/// - If `language_model.lm_head.weight` is absent afterwards, it is aliased to
///   `language_model.model.embed_tokens.weight` when that entry exists.
///
/// - Parameter weights: Key/array pairs as loaded from the checkpoint.
/// - Returns: A new dictionary holding only the normalized keys.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    // Build the result from scratch. Seeding it with a copy of `weights` would
    // leave both the skipped keys and the pre-rename keys in the output.
    var sanitizedWeights = [String: MLXArray]()
    sanitizedWeights.reserveCapacity(weights.count)

    for (key, value) in weights {
        // Check this before the rename test: an inv_freq key that lacks the
        // nested `language_model.model` path must be dropped, not renamed.
        if key.contains("self_attn.rotary_emb.inv_freq") {
            continue
        }

        let alreadyNested =
            key.contains("language_model.model") || key.contains("language_model.lm_head")
        let targetKey =
            alreadyNested
            ? key
            : key.replacingOccurrences(of: "language_model", with: "language_model.model")
        sanitizedWeights[targetKey] = value
    }

    // Handle tied lm_head weights: reuse the token embedding when no separate
    // output projection was provided.
    let lmHeadKey = "language_model.lm_head.weight"
    if sanitizedWeights[lmHeadKey] == nil,
        let embedWeight = sanitizedWeights["language_model.model.embed_tokens.weight"]
    {
        sanitizedWeights[lmHeadKey] = embedWeight
    }

    return sanitizedWeights
}
1372
1366
}
@@ -1676,7 +1670,6 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1676
1670
self . _languageModel. wrappedValue = LanguageModel ( config: config. textConfig)
1677
1671
self . _visionTower. wrappedValue = Gemma3nVisionModel ( config: config. visionConfig)
1678
1672
self . _audioTower. wrappedValue = Gemma3nAudioModel ( config: config. audioConfig)
1679
-
1680
1673
self . _embedVision. wrappedValue = Gemma3nMultimodalEmbedder (
1681
1674
multimodalConfig: config. visionConfig,
1682
1675
textConfig: config. textConfig
@@ -1893,20 +1886,16 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1893
1886
return languageModel ( inputs: inputs, cache: convertedCache) . logits
1894
1887
}
1895
1888
1896
- // In class Gemma3n
1897
1889
/// Strips a leading `model.` from every weight key.
///
/// Keys without that prefix are copied unchanged. Only the prefix is removed;
/// the remainder of the key is preserved character for character.
///
/// - Parameter weights: Key/array pairs as loaded from the checkpoint.
/// - Returns: A new dictionary keyed by the prefix-free names.
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    let prefix = "model."
    var sanitizedWeights = [String: MLXArray]()
    sanitizedWeights.reserveCapacity(weights.count)

    for (key, value) in weights {
        // Drop exactly the prefix. Splitting on "." and re-joining would also
        // collapse empty components (e.g. "a..b" -> "a.b"), because `split`
        // omits empty subsequences by default.
        let targetKey = key.hasPrefix(prefix) ? String(key.dropFirst(prefix.count)) : key
        sanitizedWeights[targetKey] = value
    }

    return sanitizedWeights
}
1912
1901
@@ -1937,14 +1926,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1937
1926
weights. merge ( fileWeights) { _, new in new }
1938
1927
}
1939
1928
1940
- // Main sanitization (remove "model." prefix)
1941
1929
var sanitizedWeights = model. sanitize ( weights: weights)
1942
-
1943
- // Vision model sanitization (transpose conv weights)
1944
- sanitizedWeights = Gemma3nVisionModel . sanitizeWeights ( sanitizedWeights)
1945
-
1946
- // Audio model sanitization (transpose conv weights)
1947
- sanitizedWeights = model. audioTower. sanitize ( weights: sanitizedWeights)
1930
+ sanitizedWeights = model. visionTower. sanitize ( weights: sanitizedWeights)
1931
+ // The audio and language sanitization is not done in the Python implementation
1932
+ // sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933
+ // sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
1948
1934
1949
1935
// Handle tied lm_head weights
1950
1936
if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
@@ -1992,7 +1978,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
1992
1978
let maxForward : Int
1993
1979
1994
1980
@ModuleInfo ( key: " pos_proj " ) var posProj : Linear
1995
- @ ModuleInfo ( key : " inv_timescales " ) var invTimescales : MLXArray
1981
+ private let _invTimescales : MLXArray
1996
1982
1997
1983
init ( config: AudioConfig ) {
1998
1984
self . config = config
@@ -2016,7 +2002,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
2016
2002
MLXArray ( 0 ..< numTimescales) . asType ( . float32) * ( - logTimescaleIncrement)
2017
2003
)
2018
2004
2019
- self . _invTimescales. wrappedValue = expandedDimensions (
2005
+ self . _invTimescales = expandedDimensions (
2020
2006
expandedDimensions ( invTimescales, axis: 0 ) ,
2021
2007
axis: 0
2022
2008
)
@@ -2028,7 +2014,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
2028
2014
assert ( position. ndim == 2 )
2029
2015
let positionFloat = expandedDimensions ( position. asType ( . float32) , axis: - 1 )
2030
2016
2031
- let scaledTime = positionFloat * invTimescales
2017
+ let scaledTime = positionFloat * _invTimescales
2032
2018
let timingSignal = concatenated ( [ sin ( scaledTime) , cos ( scaledTime) ] , axis: - 1 )
2033
2019
return timingSignal. asType ( dtype)
2034
2020
}
@@ -2328,6 +2314,7 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
2328
2314
2329
2315
let fInPadded = currentFForBlockInput + padFLeft + padFRight
2330
2316
let fOutAfterConv = ( fInPadded - kernelW) / strideW + 1
2317
+
2331
2318
calculatedFOutDims. append ( fOutAfterConv)
2332
2319
currentFForBlockInput = fOutAfterConv
2333
2320
}
@@ -2389,8 +2376,8 @@ private class Gemma3nAudioAttention: Module {
2389
2376
let attentionLogitsSoftCap : Float
2390
2377
let contextSize : Int
2391
2378
let qScale : Float
2392
- let localCausalValidMask : MLXArray
2393
- let softcap : MLXArray
2379
+ private let _localCausalValidMask : MLXArray
2380
+ private let _softcap : MLXArray
2394
2381
2395
2382
@ModuleInfo ( key: " relative_position_embedding " ) var relativePositionEmbedding :
2396
2383
Gemma3nAudioRelativePositionEmbedding
@@ -2434,9 +2421,10 @@ private class Gemma3nAudioAttention: Module {
2434
2421
)
2435
2422
2436
2423
let localCausalValidMaskTemp = MLXArray . ones ( [ chunkSize, contextSize] , dtype: . bool)
2437
- self . localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask .&& upperCausalMask
2424
+ self . _localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask
2425
+ .&& upperCausalMask
2438
2426
2439
- self . softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
2427
+ self . _softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
2440
2428
2441
2429
super. init ( )
2442
2430
}
@@ -2536,7 +2524,7 @@ private class Gemma3nAudioAttention: Module {
2536
2524
2537
2525
let conditionFromCausality = expandedDimensions (
2538
2526
expandedDimensions (
2539
- expandedDimensions ( localCausalValidMask , axis: 0 ) ,
2527
+ expandedDimensions ( _localCausalValidMask , axis: 0 ) ,
2540
2528
axis: 0
2541
2529
) ,
2542
2530
axis: 0
@@ -2547,9 +2535,9 @@ private class Gemma3nAudioAttention: Module {
2547
2535
var logits = relativePositionEmbedding ( queryBlocks, keyBlocks)
2548
2536
2549
2537
// Apply attention logit softcap
2550
- logits = logits / softcap
2538
+ logits = logits / _softcap
2551
2539
logits = tanh ( logits)
2552
- logits = logits * softcap
2540
+ logits = logits * _softcap
2553
2541
2554
2542
// Apply the combined mask
2555
2543
logits = MLX . where (
@@ -2591,10 +2579,10 @@ private class Gemma3nAudioConformerAttention: Module {
2591
2579
let postInFeatures : Int
2592
2580
private let _gradientClipping : MLXArray
2593
2581
2594
- @ModuleInfo var preAttnNorm : Gemma3nRMSNormWithScale
2582
+ @ModuleInfo ( key : " pre_attn_norm " ) var preAttnNorm : Gemma3nRMSNormWithScale
2595
2583
@ModuleInfo var attn : Gemma3nAudioAttention
2596
2584
@ModuleInfo var post : Linear
2597
- @ModuleInfo var postNorm : Gemma3nRMSNormWithScale
2585
+ @ModuleInfo ( key : " post_norm " ) var postNorm : Gemma3nRMSNormWithScale
2598
2586
2599
2587
init ( config: AudioConfig ) {
2600
2588
self . config = config
@@ -2737,17 +2725,17 @@ private class Gemma3nAudioConformerLightConv1d: Module {
2737
2725
// MARK: - Conformer Block
2738
2726
private class Gemma3nAudioConformerBlock : Module {
2739
2727
let config : AudioConfig
2740
- private let gradientClipping : MLXArray
2728
+ private let _gradientClipping : MLXArray
2741
2729
2742
- @ModuleInfo var ffwLayerStart : Gemma3nAudioConformerFeedForward
2730
+ @ModuleInfo ( key : " ffw_layer_start " ) var ffwLayerStart : Gemma3nAudioConformerFeedForward
2743
2731
@ModuleInfo var attention : Gemma3nAudioConformerAttention
2744
2732
@ModuleInfo var lconv1d : Gemma3nAudioConformerLightConv1d
2745
- @ModuleInfo var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2733
+ @ModuleInfo ( key : " ffw_layer_end " ) var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2746
2734
@ModuleInfo var norm : Gemma3nRMSNormWithScale
2747
2735
2748
2736
init ( config: AudioConfig ) {
2749
2737
self . config = config
2750
- self . gradientClipping = MLXArray ( config. gradientClipping)
2738
+ self . _gradientClipping = MLXArray ( config. gradientClipping)
2751
2739
2752
2740
self . _ffwLayerStart. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2753
2741
self . _attention. wrappedValue = Gemma3nAudioConformerAttention ( config: config)
@@ -2771,7 +2759,7 @@ private class Gemma3nAudioConformerBlock: Module {
2771
2759
2772
2760
result = lconv1d ( audioencodingsForLconvInput)
2773
2761
result = ffwLayerEnd ( result)
2774
- result = clip ( result, min: - gradientClipping , max: gradientClipping )
2762
+ result = clip ( result, min: - _gradientClipping , max: _gradientClipping )
2775
2763
return norm ( result)
2776
2764
}
2777
2765
}
@@ -2856,7 +2844,8 @@ private func numGroups(groupSize: Int?, channels: Int) -> Int {
2856
2844
}
2857
2845
// NOTE: groupSize == 1 -> depthwise conv
2858
2846
assert ( channels % groupSize == 0 )
2859
- return channels / groupSize
2847
+ let groups = channels / groupSize
2848
+ return groups
2860
2849
}
2861
2850
2862
2851
private func makeDivisible(
@@ -3082,6 +3071,7 @@ private class EdgeResidual: Module, UnaryLayer {
3082
3071
self . hasSkip = ( inChannels == outChannels && stride == 1 ) && !noskip
3083
3072
3084
3073
let padding = ( expKernelSize - 1 ) / 2
3074
+
3085
3075
self . _convExp. wrappedValue = Conv2d (
3086
3076
inputChannels: inChannels,
3087
3077
outputChannels: midChannels,
@@ -3139,17 +3129,17 @@ private class MultiQueryAttention2d: Module {
3139
3129
let valueDim : Int
3140
3130
let scale : Float
3141
3131
3142
- @ModuleInfo var queryProj : Conv2d
3132
+ @ModuleInfo ( key : " query_proj " ) var queryProj : Conv2d
3143
3133
3144
- @ModuleInfo var keyDownConv : UnaryLayer
3145
- @ModuleInfo var keyNorm : UnaryLayer
3146
- @ModuleInfo var valueDownConv : UnaryLayer
3147
- @ModuleInfo var valueNorm : UnaryLayer
3134
+ @ModuleInfo ( key : " key_down_conv " ) var keyDownConv : UnaryLayer
3135
+ @ModuleInfo ( key : " key_norm " ) var keyNorm : UnaryLayer
3136
+ @ModuleInfo ( key : " value_down_conv " ) var valueDownConv : UnaryLayer
3137
+ @ModuleInfo ( key : " value_norm " ) var valueNorm : UnaryLayer
3148
3138
3149
- @ModuleInfo var keyProj : Conv2d
3150
- @ModuleInfo var valueProj : Conv2d
3139
+ @ModuleInfo ( key : " key_proj " ) var keyProj : Conv2d
3140
+ @ModuleInfo ( key : " value_proj " ) var valueProj : Conv2d
3151
3141
@ModuleInfo ( key: " attn_drop " ) var attnDrop : UnaryLayer
3152
- @ModuleInfo var outputProj : Conv2d
3142
+ @ModuleInfo ( key : " output_proj " ) var outputProj : Conv2d
3153
3143
@ModuleInfo ( key: " proj_drop " ) var projDrop : UnaryLayer
3154
3144
3155
3145
init (
@@ -3195,6 +3185,7 @@ private class MultiQueryAttention2d: Module {
3195
3185
groups: dim, // Depthwise
3196
3186
bias: false
3197
3187
)
3188
+
3198
3189
self . _keyNorm. wrappedValue = RMSNormAct2d ( numChannels: dim, eps: 1e-6 , applyAct: false )
3199
3190
} else {
3200
3191
self . _keyDownConv. wrappedValue = Identity ( )
@@ -3323,8 +3314,8 @@ private class MobileAttention: Module, UnaryLayer {
3323
3314
3324
3315
@ModuleInfo var norm : RMSNormAct2d
3325
3316
@ModuleInfo var attn : MultiQueryAttention2d
3326
- @ModuleInfo var layerScale : UnaryLayer
3327
- @ModuleInfo var dropPath : Identity
3317
+ @ModuleInfo ( key : " layer_scale " ) var layerScale : UnaryLayer
3318
+ @ModuleInfo ( key : " drop_path " ) var dropPath : Identity
3328
3319
3329
3320
init (
3330
3321
inChannels: Int ,
@@ -3544,7 +3535,7 @@ private class MobileNetV5MultiScaleFusionAdapter: Module {
3544
3535
3545
3536
@ModuleInfo var ffn : UniversalInvertedResidual
3546
3537
@ModuleInfo var norm : RMSNormAct2d
3547
- @ModuleInfo var avgPool : AvgPool2d
3538
+ @ModuleInfo ( key : " avg_pool " ) var avgPool : AvgPool2d
3548
3539
3549
3540
init (
3550
3541
inChannels: [ Int ] ,
@@ -3780,37 +3771,23 @@ private class Gemma3nVisionModel: Module {
3780
3771
}
3781
3772
3782
3773
/// Permutes 4-D convolution and attention-projection weights with
/// `transposed(0, 2, 3, 1)` unless the checkpoint appears to be converted already.
///
/// The decision is made once, by probing
/// `vision_tower.timm_model.blocks.0.0.conv_exp.weight`: if that array is 4-D
/// and its last axis is larger than axis 1, no weight is permuted.
/// NOTE(review): presumably this detects a channels-last layout — confirm
/// against the checkpoint formats in use.
///
/// - Parameter weights: Key/array pairs as loaded from the checkpoint.
/// - Returns: A copy of `weights` with the matching 4-D entries permuted.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    let probeKey = "vision_tower.timm_model.blocks.0.0.conv_exp.weight"

    let alreadyConverted: Bool
    if let probe = weights[probeKey], probe.ndim == 4 {
        let dims = probe.shape
        alreadyConverted = dims[3] > dims[1]
    } else {
        alreadyConverted = false
    }

    // Nothing to permute: hand the input back untouched.
    guard !alreadyConverted else { return weights }

    var result = weights
    for (name, tensor) in weights where tensor.ndim == 4 {
        let isConvWeight = name.contains("conv") && name.contains("weight")
        let isAttnProjection = name.contains("attn") && name.contains("proj.weight")
        if isConvWeight || isAttnProjection {
            result[name] = tensor.transposed(0, 2, 3, 1)
        }
    }
    return result
}
3816
3793
}
@@ -3828,8 +3805,9 @@ private class Gemma3nAudioModel: Module {
3828
3805
3829
3806
self . _subsampleConvProjection. wrappedValue = Gemma3nAudioSubSampleConvProjection (
3830
3807
config: config)
3831
- self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { _ in
3832
- Gemma3nAudioConformerBlock ( config: config)
3808
+
3809
+ self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { i in
3810
+ return Gemma3nAudioConformerBlock ( config: config)
3833
3811
}
3834
3812
3835
3813
super. init ( )
@@ -3914,32 +3892,25 @@ private class Gemma3nAudioModel: Module {
3914
3892
/// Sanitizes weights by transposing convolution layers if they are not
/// already in the expected MLX format.
///
/// Only entries whose key contains `conv.weight` are considered. Such an entry
/// is permuted with `transposed(0, 2, 3, 1)` when it is 4-D and
/// `checkArrayShape` reports it is not yet in the expected layout. Every other
/// entry, including keys containing `conv1d.weight`, passes through unchanged.
///
/// - Parameter weights: Key/array pairs as loaded from the checkpoint.
/// - Returns: A copy of `weights` with the qualifying entries permuted.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    var sanitizedWeights = weights

    for (key, value) in weights where key.contains("conv.weight") {
        // The rank guard keeps the 4-axis permutation from being applied to an
        // array that does not have exactly four dimensions.
        if value.ndim == 4 && !checkArrayShape(value) {
            sanitizedWeights[key] = value.transposed(0, 2, 3, 1)
        }
    }

    return sanitizedWeights
}
3945
3916
}
0 commit comments