@@ -386,76 +386,42 @@ public struct ModelConfig: Codable, Sendable {
386
386
387
387
// MARK: - Language Model Components
388
388
389
- // Base protocol for RMSNorm variants
390
- private protocol Gemma3nRMSNormProtocol : UnaryLayer {
391
- func callAsFunction( _ x: MLXArray ) -> MLXArray
392
- }
393
-
394
- // RMSNorm with scale parameter
395
- private class Gemma3nRMSNormWithScale : Module , Gemma3nRMSNormProtocol {
389
+ private class Gemma3nRMSNorm : Module {
396
390
let eps : Float
397
- let scaleShift : Float
398
- let weight : MLXArray
391
+ let scaleShift : Float ?
392
+ let weight : MLXArray ?
399
393
400
- init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float = 0.0 ) {
394
+ init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float ? = nil ) {
401
395
self . eps = eps
402
396
self . scaleShift = scaleShift
403
- self . weight = MLXArray . ones ( [ dim] )
397
+ self . weight = scaleShift != nil ? MLXArray . ones ( [ dim] ) : nil
404
398
super. init ( )
405
399
}
406
400
407
401
func callAsFunction( _ x: MLXArray ) -> MLXArray {
408
402
let output = norm ( x. asType ( . float32) )
409
- return ( output * ( weight + scaleShift) ) . asType ( x. dtype)
410
- }
411
403
412
- private func norm( _ x: MLXArray ) -> MLXArray {
413
- return x * rsqrt( x. square ( ) . mean ( axis: - 1 , keepDims: true ) + eps)
414
- }
415
- }
416
-
417
- // RMSNorm without scale parameter (no weight to load from checkpoint)
418
- private class Gemma3nRMSNormNoScale : Module , Gemma3nRMSNormProtocol {
419
- let eps : Float
420
-
421
- init ( dim: Int , eps: Float = 1e-6 ) {
422
- self . eps = eps
423
- super. init ( )
424
- }
425
-
426
- func callAsFunction( _ x: MLXArray ) -> MLXArray {
427
- let output = norm ( x. asType ( . float32) )
428
- return output. asType ( x. dtype)
404
+ if let weight, let scaleShift {
405
+ return ( output * ( weight + scaleShift) ) . asType ( x. dtype)
406
+ } else {
407
+ return output. asType ( x. dtype)
408
+ }
429
409
}
430
410
431
411
private func norm( _ x: MLXArray ) -> MLXArray {
432
412
return x * rsqrt( x. square ( ) . mean ( axis: - 1 , keepDims: true ) + eps)
433
413
}
434
414
}
435
415
436
- // Factory function to create the appropriate RMSNorm variant
437
- private func createGemma3nRMSNorm(
438
- dim: Int ,
439
- eps: Float = 1e-6 ,
440
- scaleShift: Float = 0.0 ,
441
- withScale: Bool = true
442
- ) -> any Gemma3nRMSNormProtocol {
443
- if withScale {
444
- return Gemma3nRMSNormWithScale ( dim: dim, eps: eps, scaleShift: scaleShift)
445
- } else {
446
- return Gemma3nRMSNormNoScale ( dim: dim, eps: eps)
447
- }
448
- }
449
-
450
416
private class Gemma3nLaurelBlock : Module {
451
417
@ModuleInfo ( key: " linear_left " ) var linearLeft : Linear
452
418
@ModuleInfo ( key: " linear_right " ) var linearRight : Linear
453
- @ModuleInfo ( key: " post_laurel_norm " ) var postLaurelNorm : Gemma3nRMSNormWithScale
419
+ @ModuleInfo ( key: " post_laurel_norm " ) var postLaurelNorm : Gemma3nRMSNorm
454
420
455
421
init ( config: TextConfig ) {
456
422
self . _linearLeft. wrappedValue = Linear ( config. hiddenSize, config. laurelRank, bias: false )
457
423
self . _linearRight. wrappedValue = Linear ( config. laurelRank, config. hiddenSize, bias: false )
458
- self . _postLaurelNorm. wrappedValue = Gemma3nRMSNormWithScale (
424
+ self . _postLaurelNorm. wrappedValue = Gemma3nRMSNorm (
459
425
dim: config. hiddenSize,
460
426
eps: config. rmsNormEps,
461
427
scaleShift: 0.0
@@ -570,9 +536,9 @@ private class Gemma3nAttention: Module {
570
536
@ModuleInfo ( key: " k_proj " ) var kProj : Linear
571
537
@ModuleInfo ( key: " v_proj " ) var vProj : Linear
572
538
@ModuleInfo ( key: " o_proj " ) var oProj : Linear
573
- @ModuleInfo ( key: " q_norm " ) var qNorm : Gemma3nRMSNormWithScale
574
- @ModuleInfo ( key: " k_norm " ) var kNorm : Gemma3nRMSNormWithScale
575
- @ModuleInfo ( key: " v_norm " ) var vNorm : Gemma3nRMSNormNoScale
539
+ @ModuleInfo ( key: " q_norm " ) var qNorm : Gemma3nRMSNorm
540
+ @ModuleInfo ( key: " k_norm " ) var kNorm : Gemma3nRMSNorm
541
+ @ModuleInfo ( key: " v_norm " ) var vNorm : Gemma3nRMSNorm
576
542
577
543
init ( config: TextConfig , layerIdx: Int ) {
578
544
self . isSliding =
@@ -594,11 +560,11 @@ private class Gemma3nAttention: Module {
594
560
self . _vProj. wrappedValue = Linear ( dim, numKVHeads * headDim, bias: false )
595
561
self . _oProj. wrappedValue = Linear ( numHeads * headDim, dim, bias: false )
596
562
597
- self . _qNorm. wrappedValue = Gemma3nRMSNormWithScale (
563
+ self . _qNorm. wrappedValue = Gemma3nRMSNorm (
598
564
dim: config. headDim, eps: config. rmsNormEps)
599
- self . _kNorm. wrappedValue = Gemma3nRMSNormWithScale (
565
+ self . _kNorm. wrappedValue = Gemma3nRMSNorm (
600
566
dim: config. headDim, eps: config. rmsNormEps)
601
- self . _vNorm. wrappedValue = Gemma3nRMSNormNoScale (
567
+ self . _vNorm. wrappedValue = Gemma3nRMSNorm (
602
568
dim: config. headDim,
603
569
eps: config. rmsNormEps
604
570
)
@@ -749,7 +715,7 @@ private class Gemma3nAltUp: Module {
749
715
@ModuleInfo ( key: " correction_coefs " ) var correctionCoefs : Linear
750
716
@ModuleInfo ( key: " prediction_coefs " ) var predictionCoefs : Linear
751
717
@ModuleInfo ( key: " modality_router " ) var modalityRouter : Linear
752
- @ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNormWithScale
718
+ @ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNorm
753
719
private let _routerInputScale : MLXArray
754
720
755
721
let config : TextConfig
@@ -773,7 +739,7 @@ private class Gemma3nAltUp: Module {
773
739
config. altupNumInputs,
774
740
bias: false
775
741
)
776
- self . _routerNorm. wrappedValue = Gemma3nRMSNormWithScale (
742
+ self . _routerNorm. wrappedValue = Gemma3nRMSNorm (
777
743
dim: config. hiddenSize,
778
744
eps: config. rmsNormEps,
779
745
scaleShift: 0.0
@@ -784,8 +750,13 @@ private class Gemma3nAltUp: Module {
784
750
}
785
751
786
752
func computeRouterModalities( _ x: MLXArray ) -> MLXArray {
787
- let routerInputs =
788
- routerNorm ( x) * _routerInputScale. asType ( routerNorm. weight. dtype)
753
+ guard let routerNormWeight = routerNorm. weight else {
754
+ // This should never happen, since `routerNorm` is assigned `Gemma3nRMSNorm` with `scaleShift`, so `weight` should not be nil
755
+ fatalError ( " routerNorm.weight is nil " )
756
+ }
757
+
758
+ let routerInputs = routerNorm ( x) * _routerInputScale. asType ( routerNormWeight. dtype)
759
+
789
760
let routed = modalityRouter ( routerInputs) . asType ( . float32)
790
761
return tanh ( routed)
791
762
}
@@ -875,17 +846,15 @@ private class Gemma3nDecoderLayer: Module {
875
846
876
847
@ModuleInfo ( key: " self_attn " ) var selfAttn : Gemma3nAttention
877
848
@ModuleInfo var mlp : MLP
878
- @ModuleInfo ( key: " input_layernorm " ) var inputLayernorm : Gemma3nRMSNormWithScale
879
- @ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayernorm : Gemma3nRMSNormWithScale
880
- @ModuleInfo ( key: " pre_feedforward_layernorm " ) var preFeedforwardLayernorm :
881
- Gemma3nRMSNormWithScale
882
- @ModuleInfo ( key: " post_feedforward_layernorm " ) var postFeedforwardLayernorm :
883
- Gemma3nRMSNormWithScale
849
+ @ModuleInfo ( key: " input_layernorm " ) var inputLayernorm : Gemma3nRMSNorm
850
+ @ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayernorm : Gemma3nRMSNorm
851
+ @ModuleInfo ( key: " pre_feedforward_layernorm " ) var preFeedforwardLayernorm : Gemma3nRMSNorm
852
+ @ModuleInfo ( key: " post_feedforward_layernorm " ) var postFeedforwardLayernorm : Gemma3nRMSNorm
884
853
@ModuleInfo var altup : Gemma3nAltUp
885
854
@ModuleInfo var laurel : Gemma3nLaurelBlock
886
855
@ModuleInfo ( key: " per_layer_input_gate " ) var perLayerInputGate : Linear
887
856
@ModuleInfo ( key: " per_layer_projection " ) var perLayerProjection : Linear
888
- @ModuleInfo ( key: " post_per_layer_input_norm " ) var postPerLayerInputNorm : Gemma3nRMSNormWithScale
857
+ @ModuleInfo ( key: " post_per_layer_input_norm " ) var postPerLayerInputNorm : Gemma3nRMSNorm
889
858
890
859
init ( config: TextConfig , layerIdx: Int ) {
891
860
self . config = config
@@ -901,23 +870,23 @@ private class Gemma3nDecoderLayer: Module {
901
870
== " sliding_attention "
902
871
903
872
self . _mlp. wrappedValue = MLP ( config: config, layerIdx: layerIdx)
904
- self . _inputLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
873
+ self . _inputLayernorm. wrappedValue = Gemma3nRMSNorm (
905
874
dim: hiddenSize,
906
875
eps: config. rmsNormEps,
907
876
scaleShift: 0.0
908
877
)
909
878
910
- self . _postAttentionLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
879
+ self . _postAttentionLayernorm. wrappedValue = Gemma3nRMSNorm (
911
880
dim: hiddenSize,
912
881
eps: config. rmsNormEps,
913
882
scaleShift: 0.0
914
883
)
915
- self . _preFeedforwardLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
884
+ self . _preFeedforwardLayernorm. wrappedValue = Gemma3nRMSNorm (
916
885
dim: hiddenSize,
917
886
eps: config. rmsNormEps,
918
887
scaleShift: 0.0
919
888
)
920
- self . _postFeedforwardLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
889
+ self . _postFeedforwardLayernorm. wrappedValue = Gemma3nRMSNorm (
921
890
dim: hiddenSize,
922
891
eps: config. rmsNormEps,
923
892
scaleShift: 0.0
@@ -936,7 +905,7 @@ private class Gemma3nDecoderLayer: Module {
936
905
hiddenSize,
937
906
bias: false
938
907
)
939
- self . _postPerLayerInputNorm. wrappedValue = Gemma3nRMSNormWithScale (
908
+ self . _postPerLayerInputNorm. wrappedValue = Gemma3nRMSNorm (
940
909
dim: hiddenSize,
941
910
eps: config. rmsNormEps,
942
911
scaleShift: 0.0
@@ -1049,13 +1018,12 @@ private class Gemma3Model: Module {
1049
1018
@ModuleInfo ( key: " layers " ) var layers : [ Gemma3nDecoderLayer ]
1050
1019
@ModuleInfo ( key: " embed_tokens_per_layer " ) var embedTokensPerLayer : Embedding
1051
1020
@ModuleInfo ( key: " per_layer_model_projection " ) var perLayerModelProjection : Linear
1052
- @ModuleInfo ( key: " per_layer_projection_norm " ) var perLayerProjectionNorm :
1053
- Gemma3nRMSNormWithScale
1021
+ @ModuleInfo ( key: " per_layer_projection_norm " ) var perLayerProjectionNorm : Gemma3nRMSNorm
1054
1022
1055
1023
@ModuleInfo ( key: " altup_projections " ) var altupProjections : [ Linear ]
1056
1024
@ModuleInfo ( key: " altup_unembed_projections " ) var altupUnembedProjections : [ Linear ]
1057
1025
1058
- @ModuleInfo var norm : Gemma3nRMSNormWithScale
1026
+ @ModuleInfo var norm : Gemma3nRMSNorm
1059
1027
@ModuleInfo ( key: " rope_embedding " ) var ropeEmbedding : Gemma3nRotaryEmbedding
1060
1028
@ModuleInfo ( key: " rope_embedding_local " ) var ropeEmbeddingLocal : Gemma3nRotaryEmbedding
1061
1029
@@ -1090,7 +1058,7 @@ private class Gemma3Model: Module {
1090
1058
bias: false
1091
1059
)
1092
1060
1093
- self . _perLayerProjectionNorm. wrappedValue = Gemma3nRMSNormWithScale (
1061
+ self . _perLayerProjectionNorm. wrappedValue = Gemma3nRMSNorm (
1094
1062
dim: config. hiddenSizePerLayerInput,
1095
1063
eps: config. rmsNormEps,
1096
1064
scaleShift: 0.0
@@ -1103,7 +1071,7 @@ private class Gemma3Model: Module {
1103
1071
Linear ( config. hiddenSize, config. hiddenSize, bias: false )
1104
1072
}
1105
1073
1106
- self . _norm. wrappedValue = Gemma3nRMSNormWithScale (
1074
+ self . _norm. wrappedValue = Gemma3nRMSNorm (
1107
1075
dim: config. hiddenSize,
1108
1076
eps: config. rmsNormEps,
1109
1077
scaleShift: 0.0
@@ -1375,11 +1343,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1375
1343
let textHiddenSize : Int
1376
1344
1377
1345
@ModuleInfo var embedding : Embedding
1378
- @ModuleInfo ( key: " hard_embedding_norm " ) var hardEmbeddingNorm : Gemma3nRMSNormWithScale
1379
- @ModuleInfo ( key: " soft_embedding_norm " ) var softEmbeddingNorm : Gemma3nRMSNormWithScale
1346
+ @ModuleInfo ( key: " hard_embedding_norm " ) var hardEmbeddingNorm : Gemma3nRMSNorm
1347
+ @ModuleInfo ( key: " soft_embedding_norm " ) var softEmbeddingNorm : Gemma3nRMSNorm
1380
1348
@ModuleInfo ( key: " embedding_projection " ) var embeddingProjection : Linear
1381
1349
@ModuleInfo ( key: " embedding_post_projection_norm " ) var embeddingPostProjectionNorm :
1382
- Gemma3nRMSNormNoScale
1350
+ Gemma3nRMSNorm
1383
1351
1384
1352
init ( multimodalConfig: any MultimodalConfig , textConfig: TextConfig ) {
1385
1353
self . multimodalHiddenSize = multimodalConfig. hiddenSize
@@ -1392,11 +1360,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1392
1360
embeddingCount: vocabSize,
1393
1361
dimensions: multimodalHiddenSize
1394
1362
)
1395
- self . _hardEmbeddingNorm. wrappedValue = Gemma3nRMSNormWithScale (
1363
+ self . _hardEmbeddingNorm. wrappedValue = Gemma3nRMSNorm (
1396
1364
dim: multimodalHiddenSize,
1397
1365
eps: eps
1398
1366
)
1399
- self . _softEmbeddingNorm. wrappedValue = Gemma3nRMSNormWithScale (
1367
+ self . _softEmbeddingNorm. wrappedValue = Gemma3nRMSNorm (
1400
1368
dim: multimodalHiddenSize,
1401
1369
eps: eps
1402
1370
)
@@ -1405,7 +1373,7 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1405
1373
textHiddenSize,
1406
1374
bias: false
1407
1375
)
1408
- self . _embeddingPostProjectionNorm. wrappedValue = Gemma3nRMSNormNoScale (
1376
+ self . _embeddingPostProjectionNorm. wrappedValue = Gemma3nRMSNorm (
1409
1377
dim: textHiddenSize,
1410
1378
eps: eps
1411
1379
)
@@ -2538,21 +2506,21 @@ private class Gemma3nAudioConformerAttention: Module {
2538
2506
let postInFeatures : Int
2539
2507
private let _gradientClipping : MLXArray
2540
2508
2541
- @ModuleInfo ( key: " pre_attn_norm " ) var preAttnNorm : Gemma3nRMSNormWithScale
2509
+ @ModuleInfo ( key: " pre_attn_norm " ) var preAttnNorm : Gemma3nRMSNorm
2542
2510
@ModuleInfo var attn : Gemma3nAudioAttention
2543
2511
@ModuleInfo var post : Linear
2544
- @ModuleInfo ( key: " post_norm " ) var postNorm : Gemma3nRMSNormWithScale
2512
+ @ModuleInfo ( key: " post_norm " ) var postNorm : Gemma3nRMSNorm
2545
2513
2546
2514
init ( config: AudioConfig ) {
2547
2515
self . config = config
2548
2516
let headDim = config. hiddenSize / config. confNumAttentionHeads
2549
2517
self . postInFeatures = config. hiddenSize
2550
2518
self . _gradientClipping = MLXArray ( config. gradientClipping)
2551
2519
2552
- self . _preAttnNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2520
+ self . _preAttnNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
2553
2521
self . _attn. wrappedValue = Gemma3nAudioAttention ( config: config)
2554
2522
self . _post. wrappedValue = Linear ( postInFeatures, config. hiddenSize, bias: false )
2555
- self . _postNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2523
+ self . _postNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
2556
2524
2557
2525
super. init ( )
2558
2526
}
@@ -2581,20 +2549,20 @@ private class Gemma3nAudioConformerFeedForward: Module {
2581
2549
private let _gradientClipping : MLXArray
2582
2550
private let _postLayerScale : MLXArray
2583
2551
2584
- @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2552
+ @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNorm
2585
2553
@ModuleInfo ( key: " ffw_layer_1 " ) var ffwLayer1 : Linear
2586
2554
@ModuleInfo ( key: " ffw_layer_2 " ) var ffwLayer2 : Linear
2587
- @ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNormWithScale
2555
+ @ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNorm
2588
2556
2589
2557
init ( config: AudioConfig ) {
2590
2558
self . config = config
2591
2559
self . _gradientClipping = MLXArray ( config. gradientClipping)
2592
2560
self . _postLayerScale = MLXArray ( config. confResidualWeight)
2593
2561
2594
- self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2562
+ self . _preLayerNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
2595
2563
self . _ffwLayer1. wrappedValue = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
2596
2564
self . _ffwLayer2. wrappedValue = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
2597
- self . _postLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2565
+ self . _postLayerNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
2598
2566
2599
2567
super. init ( )
2600
2568
}
@@ -2618,18 +2586,18 @@ private class Gemma3nAudioConformerLightConv1d: Module {
2618
2586
private let _gradientClipping : MLXArray
2619
2587
let causalPadding : Int
2620
2588
2621
- @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2589
+ @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNorm
2622
2590
@ModuleInfo ( key: " linear_start " ) var linearStart : Linear
2623
2591
@ModuleInfo ( key: " depthwise_conv1d " ) var depthwiseConv1d : Conv1d
2624
- @ModuleInfo ( key: " conv_norm " ) var convNorm : Gemma3nRMSNormWithScale
2592
+ @ModuleInfo ( key: " conv_norm " ) var convNorm : Gemma3nRMSNorm
2625
2593
@ModuleInfo ( key: " linear_end " ) var linearEnd : Linear
2626
2594
2627
2595
init ( config: AudioConfig ) {
2628
2596
self . config = config
2629
2597
self . _gradientClipping = MLXArray ( config. gradientClipping)
2630
2598
self . causalPadding = config. confConvKernelSize - 1
2631
2599
2632
- self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale (
2600
+ self . _preLayerNorm. wrappedValue = Gemma3nRMSNorm (
2633
2601
dim: config. hiddenSize,
2634
2602
eps: config. rmsNormEps
2635
2603
)
@@ -2647,7 +2615,7 @@ private class Gemma3nAudioConformerLightConv1d: Module {
2647
2615
groups: config. hiddenSize,
2648
2616
bias: false
2649
2617
)
2650
- self . _convNorm. wrappedValue = Gemma3nRMSNormWithScale (
2618
+ self . _convNorm. wrappedValue = Gemma3nRMSNorm (
2651
2619
dim: config. hiddenSize,
2652
2620
eps: config. rmsNormEps
2653
2621
)
@@ -2690,7 +2658,7 @@ private class Gemma3nAudioConformerBlock: Module {
2690
2658
@ModuleInfo var attention : Gemma3nAudioConformerAttention
2691
2659
@ModuleInfo var lconv1d : Gemma3nAudioConformerLightConv1d
2692
2660
@ModuleInfo ( key: " ffw_layer_end " ) var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2693
- @ModuleInfo var norm : Gemma3nRMSNormWithScale
2661
+ @ModuleInfo var norm : Gemma3nRMSNorm
2694
2662
2695
2663
init ( config: AudioConfig ) {
2696
2664
self . config = config
@@ -2700,7 +2668,7 @@ private class Gemma3nAudioConformerBlock: Module {
2700
2668
self . _attention. wrappedValue = Gemma3nAudioConformerAttention ( config: config)
2701
2669
self . _lconv1d. wrappedValue = Gemma3nAudioConformerLightConv1d ( config: config)
2702
2670
self . _ffwLayerEnd. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2703
- self . _norm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2671
+ self . _norm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
2704
2672
2705
2673
super. init ( )
2706
2674
}
0 commit comments