Skip to content

Commit 99bdc09

Browse files
committed
Clean up Gemma3nRMSNorm
1 parent 6aa0834 commit 99bdc09

File tree

1 file changed

+61
-93
lines changed

1 file changed

+61
-93
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 61 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -386,76 +386,42 @@ public struct ModelConfig: Codable, Sendable {
386386

387387
// MARK: - Language Model Components
388388

389-
// Base protocol for RMSNorm variants
390-
private protocol Gemma3nRMSNormProtocol: UnaryLayer {
391-
func callAsFunction(_ x: MLXArray) -> MLXArray
392-
}
393-
394-
// RMSNorm with scale parameter
395-
private class Gemma3nRMSNormWithScale: Module, Gemma3nRMSNormProtocol {
389+
private class Gemma3nRMSNorm: Module {
396390
let eps: Float
397-
let scaleShift: Float
398-
let weight: MLXArray
391+
let scaleShift: Float?
392+
let weight: MLXArray?
399393

400-
init(dim: Int, eps: Float = 1e-6, scaleShift: Float = 0.0) {
394+
init(dim: Int, eps: Float = 1e-6, scaleShift: Float? = nil) {
401395
self.eps = eps
402396
self.scaleShift = scaleShift
403-
self.weight = MLXArray.ones([dim])
397+
self.weight = scaleShift != nil ? MLXArray.ones([dim]) : nil
404398
super.init()
405399
}
406400

407401
func callAsFunction(_ x: MLXArray) -> MLXArray {
408402
let output = norm(x.asType(.float32))
409-
return (output * (weight + scaleShift)).asType(x.dtype)
410-
}
411403

412-
private func norm(_ x: MLXArray) -> MLXArray {
413-
return x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
414-
}
415-
}
416-
417-
// RMSNorm without scale parameter (no weight to load from checkpoint)
418-
private class Gemma3nRMSNormNoScale: Module, Gemma3nRMSNormProtocol {
419-
let eps: Float
420-
421-
init(dim: Int, eps: Float = 1e-6) {
422-
self.eps = eps
423-
super.init()
424-
}
425-
426-
func callAsFunction(_ x: MLXArray) -> MLXArray {
427-
let output = norm(x.asType(.float32))
428-
return output.asType(x.dtype)
404+
if let weight, let scaleShift {
405+
return (output * (weight + scaleShift)).asType(x.dtype)
406+
} else {
407+
return output.asType(x.dtype)
408+
}
429409
}
430410

431411
private func norm(_ x: MLXArray) -> MLXArray {
432412
return x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
433413
}
434414
}
435415

436-
// Factory function to create the appropriate RMSNorm variant
437-
private func createGemma3nRMSNorm(
438-
dim: Int,
439-
eps: Float = 1e-6,
440-
scaleShift: Float = 0.0,
441-
withScale: Bool = true
442-
) -> any Gemma3nRMSNormProtocol {
443-
if withScale {
444-
return Gemma3nRMSNormWithScale(dim: dim, eps: eps, scaleShift: scaleShift)
445-
} else {
446-
return Gemma3nRMSNormNoScale(dim: dim, eps: eps)
447-
}
448-
}
449-
450416
private class Gemma3nLaurelBlock: Module {
451417
@ModuleInfo(key: "linear_left") var linearLeft: Linear
452418
@ModuleInfo(key: "linear_right") var linearRight: Linear
453-
@ModuleInfo(key: "post_laurel_norm") var postLaurelNorm: Gemma3nRMSNormWithScale
419+
@ModuleInfo(key: "post_laurel_norm") var postLaurelNorm: Gemma3nRMSNorm
454420

455421
init(config: TextConfig) {
456422
self._linearLeft.wrappedValue = Linear(config.hiddenSize, config.laurelRank, bias: false)
457423
self._linearRight.wrappedValue = Linear(config.laurelRank, config.hiddenSize, bias: false)
458-
self._postLaurelNorm.wrappedValue = Gemma3nRMSNormWithScale(
424+
self._postLaurelNorm.wrappedValue = Gemma3nRMSNorm(
459425
dim: config.hiddenSize,
460426
eps: config.rmsNormEps,
461427
scaleShift: 0.0
@@ -570,9 +536,9 @@ private class Gemma3nAttention: Module {
570536
@ModuleInfo(key: "k_proj") var kProj: Linear
571537
@ModuleInfo(key: "v_proj") var vProj: Linear
572538
@ModuleInfo(key: "o_proj") var oProj: Linear
573-
@ModuleInfo(key: "q_norm") var qNorm: Gemma3nRMSNormWithScale
574-
@ModuleInfo(key: "k_norm") var kNorm: Gemma3nRMSNormWithScale
575-
@ModuleInfo(key: "v_norm") var vNorm: Gemma3nRMSNormNoScale
539+
@ModuleInfo(key: "q_norm") var qNorm: Gemma3nRMSNorm
540+
@ModuleInfo(key: "k_norm") var kNorm: Gemma3nRMSNorm
541+
@ModuleInfo(key: "v_norm") var vNorm: Gemma3nRMSNorm
576542

577543
init(config: TextConfig, layerIdx: Int) {
578544
self.isSliding =
@@ -594,11 +560,11 @@ private class Gemma3nAttention: Module {
594560
self._vProj.wrappedValue = Linear(dim, numKVHeads * headDim, bias: false)
595561
self._oProj.wrappedValue = Linear(numHeads * headDim, dim, bias: false)
596562

597-
self._qNorm.wrappedValue = Gemma3nRMSNormWithScale(
563+
self._qNorm.wrappedValue = Gemma3nRMSNorm(
598564
dim: config.headDim, eps: config.rmsNormEps)
599-
self._kNorm.wrappedValue = Gemma3nRMSNormWithScale(
565+
self._kNorm.wrappedValue = Gemma3nRMSNorm(
600566
dim: config.headDim, eps: config.rmsNormEps)
601-
self._vNorm.wrappedValue = Gemma3nRMSNormNoScale(
567+
self._vNorm.wrappedValue = Gemma3nRMSNorm(
602568
dim: config.headDim,
603569
eps: config.rmsNormEps
604570
)
@@ -749,7 +715,7 @@ private class Gemma3nAltUp: Module {
749715
@ModuleInfo(key: "correction_coefs") var correctionCoefs: Linear
750716
@ModuleInfo(key: "prediction_coefs") var predictionCoefs: Linear
751717
@ModuleInfo(key: "modality_router") var modalityRouter: Linear
752-
@ModuleInfo(key: "router_norm") var routerNorm: Gemma3nRMSNormWithScale
718+
@ModuleInfo(key: "router_norm") var routerNorm: Gemma3nRMSNorm
753719
private let _routerInputScale: MLXArray
754720

755721
let config: TextConfig
@@ -773,7 +739,7 @@ private class Gemma3nAltUp: Module {
773739
config.altupNumInputs,
774740
bias: false
775741
)
776-
self._routerNorm.wrappedValue = Gemma3nRMSNormWithScale(
742+
self._routerNorm.wrappedValue = Gemma3nRMSNorm(
777743
dim: config.hiddenSize,
778744
eps: config.rmsNormEps,
779745
scaleShift: 0.0
@@ -784,8 +750,13 @@ private class Gemma3nAltUp: Module {
784750
}
785751

786752
func computeRouterModalities(_ x: MLXArray) -> MLXArray {
787-
let routerInputs =
788-
routerNorm(x) * _routerInputScale.asType(routerNorm.weight.dtype)
753+
guard let routerNormWeight = routerNorm.weight else {
754+
// Unreachable in practice: `routerNorm` is constructed with a non-nil `scaleShift`, and `Gemma3nRMSNorm` allocates `weight` whenever `scaleShift` is non-nil.
755+
fatalError("routerNorm.weight is nil")
756+
}
757+
758+
let routerInputs = routerNorm(x) * _routerInputScale.asType(routerNormWeight.dtype)
759+
789760
let routed = modalityRouter(routerInputs).asType(.float32)
790761
return tanh(routed)
791762
}
@@ -875,17 +846,15 @@ private class Gemma3nDecoderLayer: Module {
875846

876847
@ModuleInfo(key: "self_attn") var selfAttn: Gemma3nAttention
877848
@ModuleInfo var mlp: MLP
878-
@ModuleInfo(key: "input_layernorm") var inputLayernorm: Gemma3nRMSNormWithScale
879-
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: Gemma3nRMSNormWithScale
880-
@ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayernorm:
881-
Gemma3nRMSNormWithScale
882-
@ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayernorm:
883-
Gemma3nRMSNormWithScale
849+
@ModuleInfo(key: "input_layernorm") var inputLayernorm: Gemma3nRMSNorm
850+
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: Gemma3nRMSNorm
851+
@ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayernorm: Gemma3nRMSNorm
852+
@ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayernorm: Gemma3nRMSNorm
884853
@ModuleInfo var altup: Gemma3nAltUp
885854
@ModuleInfo var laurel: Gemma3nLaurelBlock
886855
@ModuleInfo(key: "per_layer_input_gate") var perLayerInputGate: Linear
887856
@ModuleInfo(key: "per_layer_projection") var perLayerProjection: Linear
888-
@ModuleInfo(key: "post_per_layer_input_norm") var postPerLayerInputNorm: Gemma3nRMSNormWithScale
857+
@ModuleInfo(key: "post_per_layer_input_norm") var postPerLayerInputNorm: Gemma3nRMSNorm
889858

890859
init(config: TextConfig, layerIdx: Int) {
891860
self.config = config
@@ -901,23 +870,23 @@ private class Gemma3nDecoderLayer: Module {
901870
== "sliding_attention"
902871

903872
self._mlp.wrappedValue = MLP(config: config, layerIdx: layerIdx)
904-
self._inputLayernorm.wrappedValue = Gemma3nRMSNormWithScale(
873+
self._inputLayernorm.wrappedValue = Gemma3nRMSNorm(
905874
dim: hiddenSize,
906875
eps: config.rmsNormEps,
907876
scaleShift: 0.0
908877
)
909878

910-
self._postAttentionLayernorm.wrappedValue = Gemma3nRMSNormWithScale(
879+
self._postAttentionLayernorm.wrappedValue = Gemma3nRMSNorm(
911880
dim: hiddenSize,
912881
eps: config.rmsNormEps,
913882
scaleShift: 0.0
914883
)
915-
self._preFeedforwardLayernorm.wrappedValue = Gemma3nRMSNormWithScale(
884+
self._preFeedforwardLayernorm.wrappedValue = Gemma3nRMSNorm(
916885
dim: hiddenSize,
917886
eps: config.rmsNormEps,
918887
scaleShift: 0.0
919888
)
920-
self._postFeedforwardLayernorm.wrappedValue = Gemma3nRMSNormWithScale(
889+
self._postFeedforwardLayernorm.wrappedValue = Gemma3nRMSNorm(
921890
dim: hiddenSize,
922891
eps: config.rmsNormEps,
923892
scaleShift: 0.0
@@ -936,7 +905,7 @@ private class Gemma3nDecoderLayer: Module {
936905
hiddenSize,
937906
bias: false
938907
)
939-
self._postPerLayerInputNorm.wrappedValue = Gemma3nRMSNormWithScale(
908+
self._postPerLayerInputNorm.wrappedValue = Gemma3nRMSNorm(
940909
dim: hiddenSize,
941910
eps: config.rmsNormEps,
942911
scaleShift: 0.0
@@ -1049,13 +1018,12 @@ private class Gemma3Model: Module {
10491018
@ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer]
10501019
@ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer: Embedding
10511020
@ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Linear
1052-
@ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm:
1053-
Gemma3nRMSNormWithScale
1021+
@ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm: Gemma3nRMSNorm
10541022

10551023
@ModuleInfo(key: "altup_projections") var altupProjections: [Linear]
10561024
@ModuleInfo(key: "altup_unembed_projections") var altupUnembedProjections: [Linear]
10571025

1058-
@ModuleInfo var norm: Gemma3nRMSNormWithScale
1026+
@ModuleInfo var norm: Gemma3nRMSNorm
10591027
@ModuleInfo(key: "rope_embedding") var ropeEmbedding: Gemma3nRotaryEmbedding
10601028
@ModuleInfo(key: "rope_embedding_local") var ropeEmbeddingLocal: Gemma3nRotaryEmbedding
10611029

@@ -1090,7 +1058,7 @@ private class Gemma3Model: Module {
10901058
bias: false
10911059
)
10921060

1093-
self._perLayerProjectionNorm.wrappedValue = Gemma3nRMSNormWithScale(
1061+
self._perLayerProjectionNorm.wrappedValue = Gemma3nRMSNorm(
10941062
dim: config.hiddenSizePerLayerInput,
10951063
eps: config.rmsNormEps,
10961064
scaleShift: 0.0
@@ -1103,7 +1071,7 @@ private class Gemma3Model: Module {
11031071
Linear(config.hiddenSize, config.hiddenSize, bias: false)
11041072
}
11051073

1106-
self._norm.wrappedValue = Gemma3nRMSNormWithScale(
1074+
self._norm.wrappedValue = Gemma3nRMSNorm(
11071075
dim: config.hiddenSize,
11081076
eps: config.rmsNormEps,
11091077
scaleShift: 0.0
@@ -1375,11 +1343,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
13751343
let textHiddenSize: Int
13761344

13771345
@ModuleInfo var embedding: Embedding
1378-
@ModuleInfo(key: "hard_embedding_norm") var hardEmbeddingNorm: Gemma3nRMSNormWithScale
1379-
@ModuleInfo(key: "soft_embedding_norm") var softEmbeddingNorm: Gemma3nRMSNormWithScale
1346+
@ModuleInfo(key: "hard_embedding_norm") var hardEmbeddingNorm: Gemma3nRMSNorm
1347+
@ModuleInfo(key: "soft_embedding_norm") var softEmbeddingNorm: Gemma3nRMSNorm
13801348
@ModuleInfo(key: "embedding_projection") var embeddingProjection: Linear
13811349
@ModuleInfo(key: "embedding_post_projection_norm") var embeddingPostProjectionNorm:
1382-
Gemma3nRMSNormNoScale
1350+
Gemma3nRMSNorm
13831351

13841352
init(multimodalConfig: any MultimodalConfig, textConfig: TextConfig) {
13851353
self.multimodalHiddenSize = multimodalConfig.hiddenSize
@@ -1392,11 +1360,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
13921360
embeddingCount: vocabSize,
13931361
dimensions: multimodalHiddenSize
13941362
)
1395-
self._hardEmbeddingNorm.wrappedValue = Gemma3nRMSNormWithScale(
1363+
self._hardEmbeddingNorm.wrappedValue = Gemma3nRMSNorm(
13961364
dim: multimodalHiddenSize,
13971365
eps: eps
13981366
)
1399-
self._softEmbeddingNorm.wrappedValue = Gemma3nRMSNormWithScale(
1367+
self._softEmbeddingNorm.wrappedValue = Gemma3nRMSNorm(
14001368
dim: multimodalHiddenSize,
14011369
eps: eps
14021370
)
@@ -1405,7 +1373,7 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
14051373
textHiddenSize,
14061374
bias: false
14071375
)
1408-
self._embeddingPostProjectionNorm.wrappedValue = Gemma3nRMSNormNoScale(
1376+
self._embeddingPostProjectionNorm.wrappedValue = Gemma3nRMSNorm(
14091377
dim: textHiddenSize,
14101378
eps: eps
14111379
)
@@ -2538,21 +2506,21 @@ private class Gemma3nAudioConformerAttention: Module {
25382506
let postInFeatures: Int
25392507
private let _gradientClipping: MLXArray
25402508

2541-
@ModuleInfo(key: "pre_attn_norm") var preAttnNorm: Gemma3nRMSNormWithScale
2509+
@ModuleInfo(key: "pre_attn_norm") var preAttnNorm: Gemma3nRMSNorm
25422510
@ModuleInfo var attn: Gemma3nAudioAttention
25432511
@ModuleInfo var post: Linear
2544-
@ModuleInfo(key: "post_norm") var postNorm: Gemma3nRMSNormWithScale
2512+
@ModuleInfo(key: "post_norm") var postNorm: Gemma3nRMSNorm
25452513

25462514
init(config: AudioConfig) {
25472515
self.config = config
25482516
let headDim = config.hiddenSize / config.confNumAttentionHeads
25492517
self.postInFeatures = config.hiddenSize
25502518
self._gradientClipping = MLXArray(config.gradientClipping)
25512519

2552-
self._preAttnNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2520+
self._preAttnNorm.wrappedValue = Gemma3nRMSNorm(dim: config.hiddenSize)
25532521
self._attn.wrappedValue = Gemma3nAudioAttention(config: config)
25542522
self._post.wrappedValue = Linear(postInFeatures, config.hiddenSize, bias: false)
2555-
self._postNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2523+
self._postNorm.wrappedValue = Gemma3nRMSNorm(dim: config.hiddenSize)
25562524

25572525
super.init()
25582526
}
@@ -2581,20 +2549,20 @@ private class Gemma3nAudioConformerFeedForward: Module {
25812549
private let _gradientClipping: MLXArray
25822550
private let _postLayerScale: MLXArray
25832551

2584-
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNormWithScale
2552+
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNorm
25852553
@ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Linear
25862554
@ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Linear
2587-
@ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma3nRMSNormWithScale
2555+
@ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma3nRMSNorm
25882556

25892557
init(config: AudioConfig) {
25902558
self.config = config
25912559
self._gradientClipping = MLXArray(config.gradientClipping)
25922560
self._postLayerScale = MLXArray(config.confResidualWeight)
25932561

2594-
self._preLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2562+
self._preLayerNorm.wrappedValue = Gemma3nRMSNorm(dim: config.hiddenSize)
25952563
self._ffwLayer1.wrappedValue = Linear(config.hiddenSize, config.hiddenSize * 4, bias: false)
25962564
self._ffwLayer2.wrappedValue = Linear(config.hiddenSize * 4, config.hiddenSize, bias: false)
2597-
self._postLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2565+
self._postLayerNorm.wrappedValue = Gemma3nRMSNorm(dim: config.hiddenSize)
25982566

25992567
super.init()
26002568
}
@@ -2618,18 +2586,18 @@ private class Gemma3nAudioConformerLightConv1d: Module {
26182586
private let _gradientClipping: MLXArray
26192587
let causalPadding: Int
26202588

2621-
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNormWithScale
2589+
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNorm
26222590
@ModuleInfo(key: "linear_start") var linearStart: Linear
26232591
@ModuleInfo(key: "depthwise_conv1d") var depthwiseConv1d: Conv1d
2624-
@ModuleInfo(key: "conv_norm") var convNorm: Gemma3nRMSNormWithScale
2592+
@ModuleInfo(key: "conv_norm") var convNorm: Gemma3nRMSNorm
26252593
@ModuleInfo(key: "linear_end") var linearEnd: Linear
26262594

26272595
init(config: AudioConfig) {
26282596
self.config = config
26292597
self._gradientClipping = MLXArray(config.gradientClipping)
26302598
self.causalPadding = config.confConvKernelSize - 1
26312599

2632-
self._preLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(
2600+
self._preLayerNorm.wrappedValue = Gemma3nRMSNorm(
26332601
dim: config.hiddenSize,
26342602
eps: config.rmsNormEps
26352603
)
@@ -2647,7 +2615,7 @@ private class Gemma3nAudioConformerLightConv1d: Module {
26472615
groups: config.hiddenSize,
26482616
bias: false
26492617
)
2650-
self._convNorm.wrappedValue = Gemma3nRMSNormWithScale(
2618+
self._convNorm.wrappedValue = Gemma3nRMSNorm(
26512619
dim: config.hiddenSize,
26522620
eps: config.rmsNormEps
26532621
)
@@ -2690,7 +2658,7 @@ private class Gemma3nAudioConformerBlock: Module {
26902658
@ModuleInfo var attention: Gemma3nAudioConformerAttention
26912659
@ModuleInfo var lconv1d: Gemma3nAudioConformerLightConv1d
26922660
@ModuleInfo(key: "ffw_layer_end") var ffwLayerEnd: Gemma3nAudioConformerFeedForward
2693-
@ModuleInfo var norm: Gemma3nRMSNormWithScale
2661+
@ModuleInfo var norm: Gemma3nRMSNorm
26942662

26952663
init(config: AudioConfig) {
26962664
self.config = config
@@ -2700,7 +2668,7 @@ private class Gemma3nAudioConformerBlock: Module {
27002668
self._attention.wrappedValue = Gemma3nAudioConformerAttention(config: config)
27012669
self._lconv1d.wrappedValue = Gemma3nAudioConformerLightConv1d(config: config)
27022670
self._ffwLayerEnd.wrappedValue = Gemma3nAudioConformerFeedForward(config: config)
2703-
self._norm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
2671+
self._norm.wrappedValue = Gemma3nRMSNorm(dim: config.hiddenSize)
27042672

27052673
super.init()
27062674
}

0 commit comments

Comments
 (0)