@@ -395,12 +395,12 @@ private protocol Gemma3nRMSNormProtocol: UnaryLayer {
395
395
private class Gemma3nRMSNormWithScale : Module , Gemma3nRMSNormProtocol {
396
396
let eps : Float
397
397
let scaleShift : Float
398
- @ ModuleInfo var weight : MLXArray
398
+ let weight : MLXArray
399
399
400
400
init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float = 0.0 ) {
401
401
self . eps = eps
402
402
self . scaleShift = scaleShift
403
- self . _weight . wrappedValue = MLXArray . ones ( [ dim] )
403
+ self . weight = MLXArray . ones ( [ dim] )
404
404
super. init ( )
405
405
}
406
406
@@ -495,8 +495,8 @@ private class Gemma3nRotaryEmbedding: Module {
495
495
let originalMaxSeqLen : Int
496
496
let config : TextConfig
497
497
let attentionScaling : Float
498
- let _invFreq : MLXArray
499
- let _originalInvFreq : MLXArray
498
+ private let _invFreq : MLXArray
499
+ private let _originalInvFreq : MLXArray
500
500
501
501
init ( config: TextConfig ) {
502
502
if let ropeScaling = config. ropeScaling {
@@ -750,7 +750,7 @@ private class Gemma3nAltUp: Module {
750
750
@ModuleInfo ( key: " prediction_coefs " ) var predictionCoefs : Linear
751
751
@ModuleInfo ( key: " modality_router " ) var modalityRouter : Linear
752
752
@ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNormWithScale
753
- let _routerInputScale : MLXArray
753
+ private let _routerInputScale : MLXArray
754
754
755
755
let config : TextConfig
756
756
@@ -2340,7 +2340,7 @@ private class Gemma3nAudioAttention: Module {
2340
2340
2341
2341
@ModuleInfo ( key: " relative_position_embedding " ) var relativePositionEmbedding :
2342
2342
Gemma3nAudioRelativePositionEmbedding
2343
- @ ModuleInfo ( key : " per_dim_scale " ) var perDimScale : MLXArray
2343
+ private let _perDimScale : MLXArray
2344
2344
@ModuleInfo ( key: " q_proj " ) var qProj : Linear
2345
2345
@ModuleInfo ( key: " k_proj " ) var kProj : Linear
2346
2346
@ModuleInfo ( key: " v_proj " ) var vProj : Linear
@@ -2359,7 +2359,7 @@ private class Gemma3nAudioAttention: Module {
2359
2359
2360
2360
self . _relativePositionEmbedding. wrappedValue = Gemma3nAudioRelativePositionEmbedding (
2361
2361
config: config)
2362
- self . _perDimScale. wrappedValue = MLXArray . zeros ( [ headDim] )
2362
+ self . _perDimScale = MLXArray . zeros ( [ headDim] )
2363
2363
2364
2364
self . _qProj. wrappedValue = Linear ( hiddenSize, numHeads * headDim, bias: false )
2365
2365
self . _kProj. wrappedValue = Linear ( hiddenSize, numHeads * headDim, bias: false )
@@ -2460,7 +2460,7 @@ private class Gemma3nAudioAttention: Module {
2460
2460
Array ( x. shape. dropLast ( ) ) + [ numHeads, headDim]
2461
2461
)
2462
2462
2463
- let perDimScaleSp = logAddExp ( perDimScale , MLXArray ( 0.0 ) )
2463
+ let perDimScaleSp = logAddExp ( _perDimScale , MLXArray ( 0.0 ) )
2464
2464
let broadcastShape = [ 1 , 1 , 1 , headDim]
2465
2465
let perDimScaleSpBroadcast = perDimScaleSp. reshaped ( broadcastShape)
2466
2466
let scaledQueryStates = queryStates * qScale * perDimScaleSpBroadcast
@@ -2728,19 +2728,19 @@ private class Gemma3nAudioConformerBlock: Module {
2728
2728
// MARK: - Layer Scale 2D
2729
2729
private class LayerScale2d : Module , UnaryLayer {
2730
2730
let inplace : Bool
2731
- @ ModuleInfo var gamma : MLXArray
2731
+ private let _gamma : MLXArray
2732
2732
2733
2733
init ( dim: Int , initValues: Float = 1e-5 , inplace: Bool = false ) {
2734
2734
self . inplace = inplace
2735
- self . _gamma. wrappedValue = MLXArray ( initValues) * MLXArray. ones ( [ dim] )
2735
+ self . _gamma = MLXArray ( initValues) * MLXArray. ones ( [ dim] )
2736
2736
super. init ( )
2737
2737
}
2738
2738
2739
2739
func callAsFunction( _ x: MLXArray ) -> MLXArray {
2740
2740
if inplace {
2741
- return x * gamma
2741
+ return x * _gamma
2742
2742
} else {
2743
- return x * gamma
2743
+ return x * _gamma
2744
2744
}
2745
2745
}
2746
2746
}
0 commit comments