@@ -1340,33 +1340,27 @@ private class LanguageModel: Module, KVCacheDimensionProvider {
1340
1340
}
1341
1341
1342
1342
/// Normalizes checkpoint keys for the language model.
///
/// - Entries whose key contains `self_attn.rotary_emb.inv_freq` are dropped.
/// - In every other key that contains neither `language_model.model` nor
///   `language_model.lm_head`, each occurrence of `language_model` is rewritten to
///   `language_model.model`.
/// - If `language_model.lm_head.weight` is absent afterwards, it is filled in from
///   `language_model.model.embed_tokens.weight` when that entry exists (tied weights).
///
/// - Parameter weights: Checkpoint tensors keyed by parameter path.
/// - Returns: A new dictionary; `weights` itself is not modified.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    // Build the result from scratch. Seeding it with a copy of `weights` would keep the
    // skipped `inv_freq` entries and leave every renamed key's original behind as well.
    var sanitizedWeights = [String: MLXArray]()
    sanitizedWeights.reserveCapacity(weights.count)

    for (k, v) in weights {
        // Checked first so these entries are dropped no matter how the key is prefixed.
        if k.contains("self_attn.rotary_emb.inv_freq") {
            continue
        }

        if !k.contains("language_model.model") && !k.contains("language_model.lm_head") {
            // Keys without "language_model" at all come out of the replacement unchanged.
            let newKey = k.replacingOccurrences(
                of: "language_model", with: "language_model.model")
            sanitizedWeights[newKey] = v
        } else {
            sanitizedWeights[k] = v
        }
    }

    // Handle tied lm_head weights
    let lmHeadKey = "language_model.lm_head.weight"
    if sanitizedWeights[lmHeadKey] == nil,
        let embedWeight = sanitizedWeights["language_model.model.embed_tokens.weight"]
    {
        sanitizedWeights[lmHeadKey] = embedWeight
    }

    return sanitizedWeights
}
1372
1366
}
@@ -1676,7 +1670,6 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1676
1670
self . _languageModel. wrappedValue = LanguageModel ( config: config. textConfig)
1677
1671
self . _visionTower. wrappedValue = Gemma3nVisionModel ( config: config. visionConfig)
1678
1672
self . _audioTower. wrappedValue = Gemma3nAudioModel ( config: config. audioConfig)
1679
-
1680
1673
self . _embedVision. wrappedValue = Gemma3nMultimodalEmbedder (
1681
1674
multimodalConfig: config. visionConfig,
1682
1675
textConfig: config. textConfig
@@ -1893,20 +1886,16 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1893
1886
return languageModel ( inputs: inputs, cache: convertedCache) . logits
1894
1887
}
1895
1888
1896
- // In class Gemma3n
1897
1889
/// Strips the leading `model.` component from checkpoint keys.
///
/// Keys that do not start with `model.` are carried over unchanged.
///
/// - Parameter weights: Checkpoint tensors keyed by parameter path.
/// - Returns: A new dictionary with the re-keyed entries.
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    return weights.reduce(into: [String: MLXArray]()) { renamed, entry in
        guard entry.key.starts(with: "model.") else {
            renamed[entry.key] = entry.value
            return
        }
        // Re-join every dot-separated component after the first one.
        let components = entry.key.split(separator: ".")
        let strippedKey = components.dropFirst().joined(separator: ".")
        renamed[strippedKey] = entry.value
    }
}
1912
1901
@@ -1937,14 +1926,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1937
1926
weights. merge ( fileWeights) { _, new in new }
1938
1927
}
1939
1928
1940
- // Main sanitization (remove "model." prefix)
1941
1929
var sanitizedWeights = model. sanitize ( weights: weights)
1942
-
1943
- // Vision model sanitization (transpose conv weights)
1944
- sanitizedWeights = Gemma3nVisionModel . sanitizeWeights ( sanitizedWeights)
1945
-
1946
- // Audio model sanitization (transpose conv weights)
1947
- sanitizedWeights = model. audioTower. sanitize ( weights: sanitizedWeights)
1930
+ sanitizedWeights = model. visionTower. sanitize ( weights: sanitizedWeights)
1931
+ // The audio and language sanitization is not done in the Python implementation
1932
+ // sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933
+ // sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
1948
1934
1949
1935
// Handle tied lm_head weights
1950
1936
if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
@@ -1992,7 +1978,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
1992
1978
let maxForward : Int
1993
1979
1994
1980
@ModuleInfo ( key: " pos_proj " ) var posProj : Linear
1995
- @ ModuleInfo ( key : " inv_timescales " ) var invTimescales : MLXArray
1981
+ private let _invTimescales : MLXArray
1996
1982
1997
1983
init ( config: AudioConfig ) {
1998
1984
self . config = config
@@ -2016,7 +2002,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
2016
2002
MLXArray ( 0 ..< numTimescales) . asType ( . float32) * ( - logTimescaleIncrement)
2017
2003
)
2018
2004
2019
- self . _invTimescales. wrappedValue = expandedDimensions (
2005
+ self . _invTimescales = expandedDimensions (
2020
2006
expandedDimensions ( invTimescales, axis: 0 ) ,
2021
2007
axis: 0
2022
2008
)
@@ -2028,7 +2014,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
2028
2014
assert ( position. ndim == 2 )
2029
2015
let positionFloat = expandedDimensions ( position. asType ( . float32) , axis: - 1 )
2030
2016
2031
- let scaledTime = positionFloat * invTimescales
2017
+ let scaledTime = positionFloat * _invTimescales
2032
2018
let timingSignal = concatenated ( [ sin ( scaledTime) , cos ( scaledTime) ] , axis: - 1 )
2033
2019
return timingSignal. asType ( dtype)
2034
2020
}
@@ -2328,6 +2314,7 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
2328
2314
2329
2315
let fInPadded = currentFForBlockInput + padFLeft + padFRight
2330
2316
let fOutAfterConv = ( fInPadded - kernelW) / strideW + 1
2317
+
2331
2318
calculatedFOutDims. append ( fOutAfterConv)
2332
2319
currentFForBlockInput = fOutAfterConv
2333
2320
}
@@ -2389,8 +2376,8 @@ private class Gemma3nAudioAttention: Module {
2389
2376
let attentionLogitsSoftCap : Float
2390
2377
let contextSize : Int
2391
2378
let qScale : Float
2392
- let localCausalValidMask : MLXArray
2393
- let softcap : MLXArray
2379
+ private let _localCausalValidMask : MLXArray
2380
+ private let _softcap : MLXArray
2394
2381
2395
2382
@ModuleInfo ( key: " relative_position_embedding " ) var relativePositionEmbedding :
2396
2383
Gemma3nAudioRelativePositionEmbedding
@@ -2434,9 +2421,10 @@ private class Gemma3nAudioAttention: Module {
2434
2421
)
2435
2422
2436
2423
let localCausalValidMaskTemp = MLXArray . ones ( [ chunkSize, contextSize] , dtype: . bool)
2437
- self . localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask .&& upperCausalMask
2424
+ self . _localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask
2425
+ .&& upperCausalMask
2438
2426
2439
- self . softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
2427
+ self . _softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
2440
2428
2441
2429
super. init ( )
2442
2430
}
@@ -2536,7 +2524,7 @@ private class Gemma3nAudioAttention: Module {
2536
2524
2537
2525
let conditionFromCausality = expandedDimensions (
2538
2526
expandedDimensions (
2539
- expandedDimensions ( localCausalValidMask , axis: 0 ) ,
2527
+ expandedDimensions ( _localCausalValidMask , axis: 0 ) ,
2540
2528
axis: 0
2541
2529
) ,
2542
2530
axis: 0
@@ -2547,9 +2535,9 @@ private class Gemma3nAudioAttention: Module {
2547
2535
var logits = relativePositionEmbedding ( queryBlocks, keyBlocks)
2548
2536
2549
2537
// Apply attention logit softcap
2550
- logits = logits / softcap
2538
+ logits = logits / _softcap
2551
2539
logits = tanh ( logits)
2552
- logits = logits * softcap
2540
+ logits = logits * _softcap
2553
2541
2554
2542
// Apply the combined mask
2555
2543
logits = MLX . where (
@@ -2635,8 +2623,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
2635
2623
private let _postLayerScale : MLXArray
2636
2624
2637
2625
@ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2638
- @ ModuleInfo ( key : " ffw_layer_1 " ) var ffwLayer1 : Linear
2639
- @ ModuleInfo ( key : " ffw_layer_2 " ) var ffwLayer2 : Linear
2626
+ private let _ffwLayer1 : Linear
2627
+ private let _ffwLayer2 : Linear
2640
2628
@ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNormWithScale
2641
2629
2642
2630
init ( config: AudioConfig ) {
@@ -2645,8 +2633,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
2645
2633
self . _postLayerScale = MLXArray ( config. confResidualWeight)
2646
2634
2647
2635
self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2648
- self . _ffwLayer1. wrappedValue = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
2649
- self . _ffwLayer2. wrappedValue = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
2636
+ self . _ffwLayer1 = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
2637
+ self . _ffwLayer2 = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
2650
2638
self . _postLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2651
2639
2652
2640
super. init ( )
@@ -2656,9 +2644,9 @@ private class Gemma3nAudioConformerFeedForward: Module {
2656
2644
let residual = x
2657
2645
let clippedX = clip ( x, min: - _gradientClipping, max: _gradientClipping)
2658
2646
var result = preLayerNorm ( clippedX)
2659
- result = ffwLayer1 ( result)
2647
+ result = _ffwLayer1 ( result)
2660
2648
result = silu ( result)
2661
- result = ffwLayer2 ( result)
2649
+ result = _ffwLayer2 ( result)
2662
2650
let clippedResult = clip ( result, min: - _gradientClipping, max: _gradientClipping)
2663
2651
let normedResult = postLayerNorm ( clippedResult)
2664
2652
return residual + ( normedResult * _postLayerScale)
@@ -2737,22 +2725,22 @@ private class Gemma3nAudioConformerLightConv1d: Module {
2737
2725
// MARK: - Conformer Block
2738
2726
private class Gemma3nAudioConformerBlock : Module {
2739
2727
let config : AudioConfig
2740
- private let gradientClipping : MLXArray
2728
+ private let _gradientClipping : MLXArray
2741
2729
2742
2730
@ModuleInfo var ffwLayerStart : Gemma3nAudioConformerFeedForward
2743
2731
@ModuleInfo var attention : Gemma3nAudioConformerAttention
2744
2732
@ModuleInfo var lconv1d : Gemma3nAudioConformerLightConv1d
2745
- @ ModuleInfo var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2733
+ private let _ffwLayerEnd : Gemma3nAudioConformerFeedForward
2746
2734
@ModuleInfo var norm : Gemma3nRMSNormWithScale
2747
2735
2748
2736
init ( config: AudioConfig ) {
2749
2737
self . config = config
2750
- self . gradientClipping = MLXArray ( config. gradientClipping)
2738
+ self . _gradientClipping = MLXArray ( config. gradientClipping)
2751
2739
2752
2740
self . _ffwLayerStart. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2753
2741
self . _attention. wrappedValue = Gemma3nAudioConformerAttention ( config: config)
2754
2742
self . _lconv1d. wrappedValue = Gemma3nAudioConformerLightConv1d ( config: config)
2755
- self . _ffwLayerEnd. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2743
+ self . _ffwLayerEnd = Gemma3nAudioConformerFeedForward ( config: config)
2756
2744
self . _norm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2757
2745
2758
2746
super. init ( )
@@ -2770,8 +2758,8 @@ private class Gemma3nAudioConformerBlock: Module {
2770
2758
) . asType ( result. dtype)
2771
2759
2772
2760
result = lconv1d ( audioencodingsForLconvInput)
2773
- result = ffwLayerEnd ( result)
2774
- result = clip ( result, min: - gradientClipping , max: gradientClipping )
2761
+ result = _ffwLayerEnd ( result)
2762
+ result = clip ( result, min: - _gradientClipping , max: _gradientClipping )
2775
2763
return norm ( result)
2776
2764
}
2777
2765
}
@@ -2856,7 +2844,8 @@ private func numGroups(groupSize: Int?, channels: Int) -> Int {
2856
2844
}
2857
2845
// NOTE: groupSize == 1 -> depthwise conv
2858
2846
assert ( channels % groupSize == 0 )
2859
- return channels / groupSize
2847
+ let groups = channels / groupSize
2848
+ return groups
2860
2849
}
2861
2850
2862
2851
private func makeDivisible(
@@ -3082,6 +3071,7 @@ private class EdgeResidual: Module, UnaryLayer {
3082
3071
self . hasSkip = ( inChannels == outChannels && stride == 1 ) && !noskip
3083
3072
3084
3073
let padding = ( expKernelSize - 1 ) / 2
3074
+
3085
3075
self . _convExp. wrappedValue = Conv2d (
3086
3076
inputChannels: inChannels,
3087
3077
outputChannels: midChannels,
@@ -3195,6 +3185,7 @@ private class MultiQueryAttention2d: Module {
3195
3185
groups: dim, // Depthwise
3196
3186
bias: false
3197
3187
)
3188
+
3198
3189
self . _keyNorm. wrappedValue = RMSNormAct2d ( numChannels: dim, eps: 1e-6 , applyAct: false )
3199
3190
} else {
3200
3191
self . _keyDownConv. wrappedValue = Identity ( )
@@ -3780,37 +3771,23 @@ private class Gemma3nVisionModel: Module {
3780
3771
}
3781
3772
3782
3773
/// Transposes 4-D convolution and attention-projection weights with the axis order
/// `(0, 2, 3, 1)`, unless the checkpoint already appears to be converted.
///
/// The check uses one probe tensor: if it is 4-D and its last axis is larger than its
/// second axis, no weight is transposed and the input is returned as-is.
///
/// - Parameter weights: Checkpoint tensors keyed by parameter path.
/// - Returns: A copy of `weights` with the affected tensors replaced.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    let probeKey = "vision_tower.timm_model.blocks.0.0.conv_exp.weight"

    var alreadyConverted = false
    if let probe = weights[probeKey], probe.ndim == 4 {
        let dims = probe.shape
        alreadyConverted = dims[3] > dims[1]
    }
    guard !alreadyConverted else {
        return weights
    }

    var result = weights
    for (name, tensor) in weights where tensor.ndim == 4 {
        let isConvWeight = name.contains("conv") && name.contains("weight")
        let isAttnProjection = name.contains("attn") && name.contains("proj.weight")
        if isConvWeight || isAttnProjection {
            result[name] = tensor.transposed(0, 2, 3, 1)
        }
    }
    return result
}
3816
3793
}
@@ -3828,8 +3805,9 @@ private class Gemma3nAudioModel: Module {
3828
3805
3829
3806
self . _subsampleConvProjection. wrappedValue = Gemma3nAudioSubSampleConvProjection (
3830
3807
config: config)
3831
- self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { _ in
3832
- Gemma3nAudioConformerBlock ( config: config)
3808
+
3809
+ self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { i in
3810
+ return Gemma3nAudioConformerBlock ( config: config)
3833
3811
}
3834
3812
3835
3813
super. init ( )
@@ -3914,32 +3892,25 @@ private class Gemma3nAudioModel: Module {
3914
3892
/// Sanitizes weights by transposing convolution layers if they are not
/// already in the expected MLX format.
///
/// Only keys containing `conv.weight` are candidates. Such a tensor is transposed with
/// the axis order `(0, 2, 3, 1)` when it is 4-D and `checkArrayShape` returns `false`
/// for it. Every other entry, including keys containing `conv1d.weight`, is returned
/// unchanged.
///
/// - Parameter weights: Checkpoint tensors keyed by parameter path.
/// - Returns: A copy of `weights` with the affected tensors replaced.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    var sanitizedWeights = weights
    for (k, v) in weights where k.contains("conv.weight") {
        // The rank guard keeps the 4-axis transpose from being applied to a tensor of
        // another rank; `checkArrayShape` is only consulted for 4-D tensors.
        if v.ndim == 4 && !checkArrayShape(v) {
            sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
        }
    }
    return sanitizedWeights
}
3945
3916
}
0 commit comments