Skip to content

Commit 91449cf

Browse files
committed
update for computed keys, key names
1 parent e417d7e commit 91449cf

File tree

1 file changed

+53
-51
lines changed

1 file changed

+53
-51
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 53 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -495,8 +495,8 @@ private class Gemma3nRotaryEmbedding: Module {
495495
let originalMaxSeqLen: Int
496496
let config: TextConfig
497497
let attentionScaling: Float
498-
let invFreq: MLXArray
499-
let originalInvFreq: MLXArray
498+
let _invFreq: MLXArray
499+
let _originalInvFreq: MLXArray
500500

501501
init(config: TextConfig) {
502502
if let ropeScaling = config.ropeScaling {
@@ -516,8 +516,8 @@ private class Gemma3nRotaryEmbedding: Module {
516516
self.attentionScaling = 1.0
517517

518518
let (invFreq, _) = Self.computeDefaultRopeParameters(config: config)
519-
self.invFreq = MLXArray(invFreq).asType(.float32)
520-
self.originalInvFreq = MLXArray(invFreq).asType(.float32)
519+
self._invFreq = MLXArray(invFreq).asType(.float32)
520+
self._originalInvFreq = MLXArray(invFreq).asType(.float32)
521521

522522
super.init()
523523
}
@@ -538,7 +538,7 @@ private class Gemma3nRotaryEmbedding: Module {
538538
}
539539

540540
func callAsFunction(_ x: MLXArray, positionIds: MLXArray) -> (MLXArray, MLXArray) {
541-
let invFreqExpanded = expandedDimensions(invFreq, axes: [0, 2])
541+
let invFreqExpanded = expandedDimensions(_invFreq, axes: [0, 2])
542542
let positionIdsExpanded = expandedDimensions(positionIds.asType(.float32), axes: [1])
543543

544544
let freqs = matmul(
@@ -750,7 +750,7 @@ private class Gemma3nAltUp: Module {
750750
@ModuleInfo(key: "prediction_coefs") var predictionCoefs: Linear
751751
@ModuleInfo(key: "modality_router") var modalityRouter: Linear
752752
@ModuleInfo(key: "router_norm") var routerNorm: Gemma3nRMSNormWithScale
753-
@ModuleInfo(key: "router_input_scale") var routerInputScale: MLXArray
753+
let _routerInputScale: MLXArray
754754

755755
let config: TextConfig
756756

@@ -778,14 +778,14 @@ private class Gemma3nAltUp: Module {
778778
eps: config.rmsNormEps,
779779
scaleShift: 0.0
780780
)
781-
self._routerInputScale.wrappedValue = MLXArray(pow(Float(config.hiddenSize), -1.0))
781+
self._routerInputScale = MLXArray(pow(Float(config.hiddenSize), -1.0))
782782

783783
super.init()
784784
}
785785

786786
func computeRouterModalities(_ x: MLXArray) -> MLXArray {
787787
let routerInputs =
788-
routerNorm(x) * routerInputScale.asType(routerNorm.weight.dtype)
788+
routerNorm(x) * _routerInputScale.asType(routerNorm.weight.dtype)
789789
let routed = modalityRouter(routerInputs).asType(.float32)
790790
return tanh(routed)
791791
}
@@ -1057,8 +1057,8 @@ private class Gemma3Model: Module {
10571057
let vocabSize: Int
10581058
let vocabSizePerLayerInput: Int
10591059
let numHiddenLayers: Int
1060-
private let perLayerProjectionScale: MLXArray
1061-
private let perLayerInputScale: MLXArray
1060+
private let _perLayerProjectionScale: MLXArray
1061+
private let _perLayerInputScale: MLXArray
10621062

10631063
@ModuleInfo(key: "embed_tokens") var embedTokens: Gemma3nTextScaledWordEmbedding
10641064
@ModuleInfo(key: "layers") var layers: [Gemma3nDecoderLayer] // This is correct!
@@ -1125,8 +1125,8 @@ private class Gemma3Model: Module {
11251125
scaleShift: 0.0
11261126
)
11271127

1128-
self.perLayerProjectionScale = MLXArray(pow(Float(hiddenSize), -0.5))
1129-
self.perLayerInputScale = rsqrt(MLXArray(2.0))
1128+
self._perLayerProjectionScale = MLXArray(pow(Float(hiddenSize), -0.5))
1129+
self._perLayerInputScale = rsqrt(MLXArray(2.0))
11301130

11311131
self._ropeEmbedding.wrappedValue = Gemma3nRotaryEmbedding(config: config)
11321132

@@ -1261,7 +1261,8 @@ private class Gemma3Model: Module {
12611261

12621262
func projectPerLayerInputs(_ inputsEmbeds: MLXArray, perLayerInputs: MLXArray?) -> MLXArray {
12631263
var perLayerProjection = perLayerModelProjection(inputsEmbeds)
1264-
perLayerProjection = perLayerProjection * perLayerProjectionScale.asType(inputsEmbeds.dtype)
1264+
perLayerProjection =
1265+
perLayerProjection * _perLayerProjectionScale.asType(inputsEmbeds.dtype)
12651266

12661267
perLayerProjection = perLayerProjection.reshaped(
12671268
Array(inputsEmbeds.shape.dropLast()) + [
@@ -1282,7 +1283,7 @@ private class Gemma3Model: Module {
12821283
}
12831284

12841285
return (perLayerProjection + adjustedPerLayerInputs)
1285-
* perLayerInputScale.asType(inputsEmbeds.dtype)
1286+
* _perLayerInputScale.asType(inputsEmbeds.dtype)
12861287
}
12871288
}
12881289

@@ -2335,8 +2336,8 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
23352336
let config: AudioConfig
23362337
let inputProjInFeatures: Int
23372338

2338-
@ModuleInfo var conv0: Gemma3nAudioSSCPConvBlock
2339-
@ModuleInfo var conv1: Gemma3nAudioSSCPConvBlock
2339+
@ModuleInfo(key: "conv_0") var conv0: Gemma3nAudioSSCPConvBlock
2340+
@ModuleInfo(key: "conv_1") var conv1: Gemma3nAudioSSCPConvBlock
23402341
@ModuleInfo(key: "input_proj_linear") var inputProjLinear: Linear
23412342

23422343
init(config: AudioConfig) {
@@ -2625,7 +2626,7 @@ private class Gemma3nAudioAttention: Module {
26252626
private class Gemma3nAudioConformerAttention: Module {
26262627
let config: AudioConfig
26272628
let postInFeatures: Int
2628-
private let gradientClipping: MLXArray
2629+
private let _gradientClipping: MLXArray
26292630

26302631
@ModuleInfo var preAttnNorm: Gemma3nRMSNormWithScale
26312632
@ModuleInfo var attn: Gemma3nAudioAttention
@@ -2636,7 +2637,7 @@ private class Gemma3nAudioConformerAttention: Module {
26362637
self.config = config
26372638
let headDim = config.hiddenSize / config.confNumAttentionHeads
26382639
self.postInFeatures = config.hiddenSize
2639-
self.gradientClipping = MLXArray(config.gradientClipping)
2640+
self._gradientClipping = MLXArray(config.gradientClipping)
26402641

26412642
self._preAttnNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
26422643
self._attn.wrappedValue = Gemma3nAudioAttention(config: config)
@@ -2648,7 +2649,7 @@ private class Gemma3nAudioConformerAttention: Module {
26482649

26492650
func callAsFunction(_ x: MLXArray, mask: MLXArray) -> MLXArray {
26502651
let audioencodingsInputToAttn = x
2651-
let clippedX = clip(x, min: -gradientClipping, max: gradientClipping)
2652+
let clippedX = clip(x, min: -_gradientClipping, max: _gradientClipping)
26522653
let audioencodingsNorm = preAttnNorm(clippedX)
26532654
let audioencodingsAttnOut = attn(audioencodingsNorm, mask: mask)
26542655

@@ -2659,26 +2660,26 @@ private class Gemma3nAudioConformerAttention: Module {
26592660
let audioencodingsReshaped = audioencodingsAttnOut.reshaped([b, t, numHeads * headDim])
26602661

26612662
let postResult = post(audioencodingsReshaped)
2662-
let clippedPost = clip(postResult, min: -gradientClipping, max: gradientClipping)
2663+
let clippedPost = clip(postResult, min: -_gradientClipping, max: _gradientClipping)
26632664
return audioencodingsInputToAttn + postNorm(clippedPost)
26642665
}
26652666
}
26662667

26672668
// MARK: - Conformer Feed Forward
26682669
private class Gemma3nAudioConformerFeedForward: Module {
26692670
let config: AudioConfig
2670-
private let gradientClipping: MLXArray
2671-
private let postLayerScale: MLXArray
2671+
private let _gradientClipping: MLXArray
2672+
private let _postLayerScale: MLXArray
26722673

2673-
@ModuleInfo var preLayerNorm: Gemma3nRMSNormWithScale
2674-
@ModuleInfo var ffwLayer1: Linear
2675-
@ModuleInfo var ffwLayer2: Linear
2676-
@ModuleInfo var postLayerNorm: Gemma3nRMSNormWithScale
2674+
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNormWithScale
2675+
@ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Linear
2676+
@ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Linear
2677+
@ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma3nRMSNormWithScale
26772678

26782679
init(config: AudioConfig) {
26792680
self.config = config
2680-
self.gradientClipping = MLXArray(config.gradientClipping)
2681-
self.postLayerScale = MLXArray(config.confResidualWeight)
2681+
self._gradientClipping = MLXArray(config.gradientClipping)
2682+
self._postLayerScale = MLXArray(config.confResidualWeight)
26822683

26832684
self._preLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(dim: config.hiddenSize)
26842685
self._ffwLayer1.wrappedValue = Linear(config.hiddenSize, config.hiddenSize * 4, bias: false)
@@ -2690,32 +2691,32 @@ private class Gemma3nAudioConformerFeedForward: Module {
26902691

26912692
func callAsFunction(_ x: MLXArray) -> MLXArray {
26922693
let residual = x
2693-
let clippedX = clip(x, min: -gradientClipping, max: gradientClipping)
2694+
let clippedX = clip(x, min: -_gradientClipping, max: _gradientClipping)
26942695
var result = preLayerNorm(clippedX)
26952696
result = ffwLayer1(result)
26962697
result = silu(result)
26972698
result = ffwLayer2(result)
2698-
let clippedResult = clip(result, min: -gradientClipping, max: gradientClipping)
2699+
let clippedResult = clip(result, min: -_gradientClipping, max: _gradientClipping)
26992700
let normedResult = postLayerNorm(clippedResult)
2700-
return residual + (normedResult * postLayerScale)
2701+
return residual + (normedResult * _postLayerScale)
27012702
}
27022703
}
27032704

27042705
// MARK: - Conformer Light Conv1D
27052706
private class Gemma3nAudioConformerLightConv1d: Module {
27062707
let config: AudioConfig
2707-
private let gradientClipping: MLXArray
2708+
private let _gradientClipping: MLXArray
27082709
let causalPadding: Int
27092710

2710-
@ModuleInfo var preLayerNorm: Gemma3nRMSNormWithScale
2711-
@ModuleInfo var linearStart: Linear
2712-
@ModuleInfo var depthwiseConv1d: Conv1d
2713-
@ModuleInfo var convNorm: Gemma3nRMSNormWithScale
2714-
@ModuleInfo var linearEnd: Linear
2711+
@ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma3nRMSNormWithScale
2712+
@ModuleInfo(key: "linear_start") var linearStart: Linear
2713+
@ModuleInfo(key: "depthwise_conv1d") var depthwiseConv1d: Conv1d
2714+
@ModuleInfo(key: "conv_norm") var convNorm: Gemma3nRMSNormWithScale
2715+
@ModuleInfo(key: "linear_end") var linearEnd: Linear
27152716

27162717
init(config: AudioConfig) {
27172718
self.config = config
2718-
self.gradientClipping = MLXArray(config.gradientClipping)
2719+
self._gradientClipping = MLXArray(config.gradientClipping)
27192720
self.causalPadding = config.confConvKernelSize - 1
27202721

27212722
self._preLayerNorm.wrappedValue = Gemma3nRMSNormWithScale(
@@ -2761,7 +2762,7 @@ private class Gemma3nAudioConformerLightConv1d: Module {
27612762
)
27622763

27632764
result = depthwiseConv1d(paddedAudio.transposed(0, 2, 1))
2764-
result = clip(result, min: -gradientClipping, max: gradientClipping)
2765+
result = clip(result, min: -_gradientClipping, max: _gradientClipping)
27652766
result = convNorm(result)
27662767
result = silu(result)
27672768
result = linearEnd(result)
@@ -2967,11 +2968,11 @@ private class ConvNormAct: Module, UnaryLayer {
29672968
// MARK: - Universal Inverted Residual
29682969
private class UniversalInvertedResidual: Module, UnaryLayer {
29692970
let hasSkip: Bool
2970-
@ModuleInfo var dwStart: UnaryLayer
2971-
@ModuleInfo var pwExp: ConvNormAct
2972-
@ModuleInfo var dwMid: UnaryLayer
2973-
@ModuleInfo var pwProj: ConvNormAct
2974-
@ModuleInfo var layerScale: UnaryLayer
2971+
@ModuleInfo(key: "dw_start") var dwStart: UnaryLayer
2972+
@ModuleInfo(key: "pw_exp") var pwExp: ConvNormAct
2973+
@ModuleInfo(key: "dw_mid") var dwMid: UnaryLayer
2974+
@ModuleInfo(key: "pw_proj") var pwProj: ConvNormAct
2975+
@ModuleInfo(key: "layer_scale") var layerScale: UnaryLayer
29752976

29762977
init(
29772978
inChannels: Int,
@@ -3088,9 +3089,9 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
30883089
// MARK: - Edge Residual
30893090
private class EdgeResidual: Module, UnaryLayer {
30903091
let hasSkip: Bool
3091-
@ModuleInfo var convExp: Conv2d
3092+
@ModuleInfo(key: "conv_exp") var convExp: Conv2d
30923093
@ModuleInfo var bn1: RMSNormAct2d
3093-
@ModuleInfo var convPwl: Conv2d
3094+
@ModuleInfo(key: "conv_pwl") var convPwl: Conv2d
30943095
@ModuleInfo var bn2: RMSNormAct2d
30953096

30963097
init(
@@ -3184,9 +3185,9 @@ private class MultiQueryAttention2d: Module {
31843185

31853186
@ModuleInfo var keyProj: Conv2d
31863187
@ModuleInfo var valueProj: Conv2d
3187-
@ModuleInfo var attnDrop: UnaryLayer
3188+
@ModuleInfo(key: "attn_drop") var attnDrop: UnaryLayer
31883189
@ModuleInfo var outputProj: Conv2d
3189-
@ModuleInfo var projDrop: UnaryLayer
3190+
@ModuleInfo(key: "proj_drop") var projDrop: UnaryLayer
31903191

31913192
init(
31923193
dim: Int,
@@ -3681,7 +3682,7 @@ private class MobileNetV5MultiScaleFusionAdapter: Module {
36813682

36823683
// MARK: - Vision Tower
36833684
private class VisionTower: Module {
3684-
@ModuleInfo var convStem: ConvNormAct
3685+
@ModuleInfo(key: "conv_stem") var convStem: ConvNormAct
36853686
@ModuleInfo var blocks: [[UnaryLayer]]
36863687
@ModuleInfo var msfa: MobileNetV5MultiScaleFusionAdapter
36873688

@@ -3800,7 +3801,7 @@ private class VisionTower: Module {
38003801
// MARK: - Complete Vision Model
38013802
private class Gemma3nVisionModel: Module {
38023803
let modelType: String
3803-
@ModuleInfo var timmModel: VisionTower
3804+
@ModuleInfo(key: "timm_model") var timmModel: VisionTower
38043805

38053806
init(config: VisionConfig) {
38063807
self.modelType = config.modelType
@@ -3856,7 +3857,8 @@ private class Gemma3nVisionModel: Module {
38563857
private class Gemma3nAudioModel: Module {
38573858
let config: AudioConfig
38583859

3859-
@ModuleInfo var subsampleConvProjection: Gemma3nAudioSubSampleConvProjection
3860+
@ModuleInfo(key: "subsample_conv_projection") var subsampleConvProjection:
3861+
Gemma3nAudioSubSampleConvProjection
38603862
@ModuleInfo var conformer: [Gemma3nAudioConformerBlock]
38613863

38623864
init(config: AudioConfig) {

0 commit comments

Comments
 (0)