@@ -495,8 +495,8 @@ private class Gemma3nRotaryEmbedding: Module {
495
495
let originalMaxSeqLen : Int
496
496
let config : TextConfig
497
497
let attentionScaling : Float
498
- let invFreq : MLXArray
499
- let originalInvFreq : MLXArray
498
+ let _invFreq : MLXArray
499
+ let _originalInvFreq : MLXArray
500
500
501
501
init ( config: TextConfig ) {
502
502
if let ropeScaling = config. ropeScaling {
@@ -516,8 +516,8 @@ private class Gemma3nRotaryEmbedding: Module {
516
516
self . attentionScaling = 1.0
517
517
518
518
let ( invFreq, _) = Self . computeDefaultRopeParameters ( config: config)
519
- self . invFreq = MLXArray ( invFreq) . asType ( . float32)
520
- self . originalInvFreq = MLXArray ( invFreq) . asType ( . float32)
519
+ self . _invFreq = MLXArray ( invFreq) . asType ( . float32)
520
+ self . _originalInvFreq = MLXArray ( invFreq) . asType ( . float32)
521
521
522
522
super. init ( )
523
523
}
@@ -538,7 +538,7 @@ private class Gemma3nRotaryEmbedding: Module {
538
538
}
539
539
540
540
func callAsFunction( _ x: MLXArray , positionIds: MLXArray ) -> ( MLXArray , MLXArray ) {
541
- let invFreqExpanded = expandedDimensions ( invFreq , axes: [ 0 , 2 ] )
541
+ let invFreqExpanded = expandedDimensions ( _invFreq , axes: [ 0 , 2 ] )
542
542
let positionIdsExpanded = expandedDimensions ( positionIds. asType ( . float32) , axes: [ 1 ] )
543
543
544
544
let freqs = matmul (
@@ -750,7 +750,7 @@ private class Gemma3nAltUp: Module {
750
750
@ModuleInfo ( key: " prediction_coefs " ) var predictionCoefs : Linear
751
751
@ModuleInfo ( key: " modality_router " ) var modalityRouter : Linear
752
752
@ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNormWithScale
753
- @ ModuleInfo ( key : " router_input_scale " ) var routerInputScale : MLXArray
753
+ let _routerInputScale : MLXArray
754
754
755
755
let config : TextConfig
756
756
@@ -778,14 +778,14 @@ private class Gemma3nAltUp: Module {
778
778
eps: config. rmsNormEps,
779
779
scaleShift: 0.0
780
780
)
781
- self . _routerInputScale. wrappedValue = MLXArray ( pow ( Float ( config. hiddenSize) , - 1.0 ) )
781
+ self . _routerInputScale = MLXArray ( pow ( Float ( config. hiddenSize) , - 1.0 ) )
782
782
783
783
super. init ( )
784
784
}
785
785
786
786
func computeRouterModalities( _ x: MLXArray ) -> MLXArray {
787
787
let routerInputs =
788
- routerNorm ( x) * routerInputScale . asType ( routerNorm. weight. dtype)
788
+ routerNorm ( x) * _routerInputScale . asType ( routerNorm. weight. dtype)
789
789
let routed = modalityRouter ( routerInputs) . asType ( . float32)
790
790
return tanh ( routed)
791
791
}
@@ -1057,8 +1057,8 @@ private class Gemma3Model: Module {
1057
1057
let vocabSize : Int
1058
1058
let vocabSizePerLayerInput : Int
1059
1059
let numHiddenLayers : Int
1060
- private let perLayerProjectionScale : MLXArray
1061
- private let perLayerInputScale : MLXArray
1060
+ private let _perLayerProjectionScale : MLXArray
1061
+ private let _perLayerInputScale : MLXArray
1062
1062
1063
1063
@ModuleInfo ( key: " embed_tokens " ) var embedTokens : Gemma3nTextScaledWordEmbedding
1064
1064
@ModuleInfo ( key: " layers " ) var layers : [ Gemma3nDecoderLayer ] // This is correct!
@@ -1125,8 +1125,8 @@ private class Gemma3Model: Module {
1125
1125
scaleShift: 0.0
1126
1126
)
1127
1127
1128
- self . perLayerProjectionScale = MLXArray ( pow ( Float ( hiddenSize) , - 0.5 ) )
1129
- self . perLayerInputScale = rsqrt ( MLXArray ( 2.0 ) )
1128
+ self . _perLayerProjectionScale = MLXArray ( pow ( Float ( hiddenSize) , - 0.5 ) )
1129
+ self . _perLayerInputScale = rsqrt ( MLXArray ( 2.0 ) )
1130
1130
1131
1131
self . _ropeEmbedding. wrappedValue = Gemma3nRotaryEmbedding ( config: config)
1132
1132
@@ -1261,7 +1261,8 @@ private class Gemma3Model: Module {
1261
1261
1262
1262
func projectPerLayerInputs( _ inputsEmbeds: MLXArray , perLayerInputs: MLXArray ? ) -> MLXArray {
1263
1263
var perLayerProjection = perLayerModelProjection ( inputsEmbeds)
1264
- perLayerProjection = perLayerProjection * perLayerProjectionScale. asType ( inputsEmbeds. dtype)
1264
+ perLayerProjection =
1265
+ perLayerProjection * _perLayerProjectionScale. asType ( inputsEmbeds. dtype)
1265
1266
1266
1267
perLayerProjection = perLayerProjection. reshaped (
1267
1268
Array ( inputsEmbeds. shape. dropLast ( ) ) + [
@@ -1282,7 +1283,7 @@ private class Gemma3Model: Module {
1282
1283
}
1283
1284
1284
1285
return ( perLayerProjection + adjustedPerLayerInputs)
1285
- * perLayerInputScale . asType ( inputsEmbeds. dtype)
1286
+ * _perLayerInputScale . asType ( inputsEmbeds. dtype)
1286
1287
}
1287
1288
}
1288
1289
@@ -2335,8 +2336,8 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
2335
2336
let config : AudioConfig
2336
2337
let inputProjInFeatures : Int
2337
2338
2338
- @ModuleInfo var conv0 : Gemma3nAudioSSCPConvBlock
2339
- @ModuleInfo var conv1 : Gemma3nAudioSSCPConvBlock
2339
+ @ModuleInfo ( key : " conv_0 " ) var conv0 : Gemma3nAudioSSCPConvBlock
2340
+ @ModuleInfo ( key : " conv_1 " ) var conv1 : Gemma3nAudioSSCPConvBlock
2340
2341
@ModuleInfo ( key: " input_proj_linear " ) var inputProjLinear : Linear
2341
2342
2342
2343
init ( config: AudioConfig ) {
@@ -2625,7 +2626,7 @@ private class Gemma3nAudioAttention: Module {
2625
2626
private class Gemma3nAudioConformerAttention : Module {
2626
2627
let config : AudioConfig
2627
2628
let postInFeatures : Int
2628
- private let gradientClipping : MLXArray
2629
+ private let _gradientClipping : MLXArray
2629
2630
2630
2631
@ModuleInfo var preAttnNorm : Gemma3nRMSNormWithScale
2631
2632
@ModuleInfo var attn : Gemma3nAudioAttention
@@ -2636,7 +2637,7 @@ private class Gemma3nAudioConformerAttention: Module {
2636
2637
self . config = config
2637
2638
let headDim = config. hiddenSize / config. confNumAttentionHeads
2638
2639
self . postInFeatures = config. hiddenSize
2639
- self . gradientClipping = MLXArray ( config. gradientClipping)
2640
+ self . _gradientClipping = MLXArray ( config. gradientClipping)
2640
2641
2641
2642
self . _preAttnNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2642
2643
self . _attn. wrappedValue = Gemma3nAudioAttention ( config: config)
@@ -2648,7 +2649,7 @@ private class Gemma3nAudioConformerAttention: Module {
2648
2649
2649
2650
func callAsFunction( _ x: MLXArray , mask: MLXArray ) -> MLXArray {
2650
2651
let audioencodingsInputToAttn = x
2651
- let clippedX = clip ( x, min: - gradientClipping , max: gradientClipping )
2652
+ let clippedX = clip ( x, min: - _gradientClipping , max: _gradientClipping )
2652
2653
let audioencodingsNorm = preAttnNorm ( clippedX)
2653
2654
let audioencodingsAttnOut = attn ( audioencodingsNorm, mask: mask)
2654
2655
@@ -2659,26 +2660,26 @@ private class Gemma3nAudioConformerAttention: Module {
2659
2660
let audioencodingsReshaped = audioencodingsAttnOut. reshaped ( [ b, t, numHeads * headDim] )
2660
2661
2661
2662
let postResult = post ( audioencodingsReshaped)
2662
- let clippedPost = clip ( postResult, min: - gradientClipping , max: gradientClipping )
2663
+ let clippedPost = clip ( postResult, min: - _gradientClipping , max: _gradientClipping )
2663
2664
return audioencodingsInputToAttn + postNorm( clippedPost)
2664
2665
}
2665
2666
}
2666
2667
2667
2668
// MARK: - Conformer Feed Forward
2668
2669
private class Gemma3nAudioConformerFeedForward : Module {
2669
2670
let config : AudioConfig
2670
- private let gradientClipping : MLXArray
2671
- private let postLayerScale : MLXArray
2671
+ private let _gradientClipping : MLXArray
2672
+ private let _postLayerScale : MLXArray
2672
2673
2673
- @ModuleInfo var preLayerNorm : Gemma3nRMSNormWithScale
2674
- @ModuleInfo var ffwLayer1 : Linear
2675
- @ModuleInfo var ffwLayer2 : Linear
2676
- @ModuleInfo var postLayerNorm : Gemma3nRMSNormWithScale
2674
+ @ModuleInfo ( key : " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2675
+ @ModuleInfo ( key : " ffw_layer_1 " ) var ffwLayer1 : Linear
2676
+ @ModuleInfo ( key : " ffw_layer_2 " ) var ffwLayer2 : Linear
2677
+ @ModuleInfo ( key : " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNormWithScale
2677
2678
2678
2679
init ( config: AudioConfig ) {
2679
2680
self . config = config
2680
- self . gradientClipping = MLXArray ( config. gradientClipping)
2681
- self . postLayerScale = MLXArray ( config. confResidualWeight)
2681
+ self . _gradientClipping = MLXArray ( config. gradientClipping)
2682
+ self . _postLayerScale = MLXArray ( config. confResidualWeight)
2682
2683
2683
2684
self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2684
2685
self . _ffwLayer1. wrappedValue = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
@@ -2690,32 +2691,32 @@ private class Gemma3nAudioConformerFeedForward: Module {
2690
2691
2691
2692
func callAsFunction( _ x: MLXArray ) -> MLXArray {
2692
2693
let residual = x
2693
- let clippedX = clip ( x, min: - gradientClipping , max: gradientClipping )
2694
+ let clippedX = clip ( x, min: - _gradientClipping , max: _gradientClipping )
2694
2695
var result = preLayerNorm ( clippedX)
2695
2696
result = ffwLayer1 ( result)
2696
2697
result = silu ( result)
2697
2698
result = ffwLayer2 ( result)
2698
- let clippedResult = clip ( result, min: - gradientClipping , max: gradientClipping )
2699
+ let clippedResult = clip ( result, min: - _gradientClipping , max: _gradientClipping )
2699
2700
let normedResult = postLayerNorm ( clippedResult)
2700
- return residual + ( normedResult * postLayerScale )
2701
+ return residual + ( normedResult * _postLayerScale )
2701
2702
}
2702
2703
}
2703
2704
2704
2705
// MARK: - Conformer Light Conv1D
2705
2706
private class Gemma3nAudioConformerLightConv1d : Module {
2706
2707
let config : AudioConfig
2707
- private let gradientClipping : MLXArray
2708
+ private let _gradientClipping : MLXArray
2708
2709
let causalPadding : Int
2709
2710
2710
- @ModuleInfo var preLayerNorm : Gemma3nRMSNormWithScale
2711
- @ModuleInfo var linearStart : Linear
2712
- @ModuleInfo var depthwiseConv1d : Conv1d
2713
- @ModuleInfo var convNorm : Gemma3nRMSNormWithScale
2714
- @ModuleInfo var linearEnd : Linear
2711
+ @ModuleInfo ( key : " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2712
+ @ModuleInfo ( key : " linear_start " ) var linearStart : Linear
2713
+ @ModuleInfo ( key : " depthwise_conv1d " ) var depthwiseConv1d : Conv1d
2714
+ @ModuleInfo ( key : " conv_norm " ) var convNorm : Gemma3nRMSNormWithScale
2715
+ @ModuleInfo ( key : " linear_end " ) var linearEnd : Linear
2715
2716
2716
2717
init ( config: AudioConfig ) {
2717
2718
self . config = config
2718
- self . gradientClipping = MLXArray ( config. gradientClipping)
2719
+ self . _gradientClipping = MLXArray ( config. gradientClipping)
2719
2720
self . causalPadding = config. confConvKernelSize - 1
2720
2721
2721
2722
self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale (
@@ -2761,7 +2762,7 @@ private class Gemma3nAudioConformerLightConv1d: Module {
2761
2762
)
2762
2763
2763
2764
result = depthwiseConv1d ( paddedAudio. transposed ( 0 , 2 , 1 ) )
2764
- result = clip ( result, min: - gradientClipping , max: gradientClipping )
2765
+ result = clip ( result, min: - _gradientClipping , max: _gradientClipping )
2765
2766
result = convNorm ( result)
2766
2767
result = silu ( result)
2767
2768
result = linearEnd ( result)
@@ -2967,11 +2968,11 @@ private class ConvNormAct: Module, UnaryLayer {
2967
2968
// MARK: - Universal Inverted Residual
2968
2969
private class UniversalInvertedResidual : Module , UnaryLayer {
2969
2970
let hasSkip : Bool
2970
- @ModuleInfo var dwStart : UnaryLayer
2971
- @ModuleInfo var pwExp : ConvNormAct
2972
- @ModuleInfo var dwMid : UnaryLayer
2973
- @ModuleInfo var pwProj : ConvNormAct
2974
- @ModuleInfo var layerScale : UnaryLayer
2971
+ @ModuleInfo ( key : " dw_start " ) var dwStart : UnaryLayer
2972
+ @ModuleInfo ( key : " pw_exp " ) var pwExp : ConvNormAct
2973
+ @ModuleInfo ( key : " dw_mid " ) var dwMid : UnaryLayer
2974
+ @ModuleInfo ( key : " pw_proj " ) var pwProj : ConvNormAct
2975
+ @ModuleInfo ( key : " layer_scale " ) var layerScale : UnaryLayer
2975
2976
2976
2977
init (
2977
2978
inChannels: Int ,
@@ -3088,9 +3089,9 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
3088
3089
// MARK: - Edge Residual
3089
3090
private class EdgeResidual : Module , UnaryLayer {
3090
3091
let hasSkip : Bool
3091
- @ModuleInfo var convExp : Conv2d
3092
+ @ModuleInfo ( key : " conv_exp " ) var convExp : Conv2d
3092
3093
@ModuleInfo var bn1 : RMSNormAct2d
3093
- @ModuleInfo var convPwl : Conv2d
3094
+ @ModuleInfo ( key : " conv_pwl " ) var convPwl : Conv2d
3094
3095
@ModuleInfo var bn2 : RMSNormAct2d
3095
3096
3096
3097
init (
@@ -3184,9 +3185,9 @@ private class MultiQueryAttention2d: Module {
3184
3185
3185
3186
@ModuleInfo var keyProj : Conv2d
3186
3187
@ModuleInfo var valueProj : Conv2d
3187
- @ModuleInfo var attnDrop : UnaryLayer
3188
+ @ModuleInfo ( key : " attn_drop " ) var attnDrop : UnaryLayer
3188
3189
@ModuleInfo var outputProj : Conv2d
3189
- @ModuleInfo var projDrop : UnaryLayer
3190
+ @ModuleInfo ( key : " proj_drop " ) var projDrop : UnaryLayer
3190
3191
3191
3192
init (
3192
3193
dim: Int ,
@@ -3681,7 +3682,7 @@ private class MobileNetV5MultiScaleFusionAdapter: Module {
3681
3682
3682
3683
// MARK: - Vision Tower
3683
3684
private class VisionTower : Module {
3684
- @ModuleInfo var convStem : ConvNormAct
3685
+ @ModuleInfo ( key : " conv_stem " ) var convStem : ConvNormAct
3685
3686
@ModuleInfo var blocks : [ [ UnaryLayer ] ]
3686
3687
@ModuleInfo var msfa : MobileNetV5MultiScaleFusionAdapter
3687
3688
@@ -3800,7 +3801,7 @@ private class VisionTower: Module {
3800
3801
// MARK: - Complete Vision Model
3801
3802
private class Gemma3nVisionModel : Module {
3802
3803
let modelType : String
3803
- @ModuleInfo var timmModel : VisionTower
3804
+ @ModuleInfo ( key : " timm_model " ) var timmModel : VisionTower
3804
3805
3805
3806
init ( config: VisionConfig ) {
3806
3807
self . modelType = config. modelType
@@ -3856,7 +3857,8 @@ private class Gemma3nVisionModel: Module {
3856
3857
private class Gemma3nAudioModel : Module {
3857
3858
let config : AudioConfig
3858
3859
3859
- @ModuleInfo var subsampleConvProjection : Gemma3nAudioSubSampleConvProjection
3860
+ @ModuleInfo ( key: " subsample_conv_projection " ) var subsampleConvProjection :
3861
+ Gemma3nAudioSubSampleConvProjection
3860
3862
@ModuleInfo var conformer : [ Gemma3nAudioConformerBlock ]
3861
3863
3862
3864
init ( config: AudioConfig ) {
0 commit comments