Skip to content

Commit 6aa0834

Browse files
committed
Clean up
1 parent 02767d5 commit 6aa0834

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,12 @@ private protocol Gemma3nRMSNormProtocol: UnaryLayer {
395395
private class Gemma3nRMSNormWithScale: Module, Gemma3nRMSNormProtocol {
396396
let eps: Float
397397
let scaleShift: Float
398-
@ModuleInfo var weight: MLXArray
398+
let weight: MLXArray
399399

400400
init(dim: Int, eps: Float = 1e-6, scaleShift: Float = 0.0) {
401401
self.eps = eps
402402
self.scaleShift = scaleShift
403-
self._weight.wrappedValue = MLXArray.ones([dim])
403+
self.weight = MLXArray.ones([dim])
404404
super.init()
405405
}
406406

@@ -495,8 +495,8 @@ private class Gemma3nRotaryEmbedding: Module {
495495
let originalMaxSeqLen: Int
496496
let config: TextConfig
497497
let attentionScaling: Float
498-
let _invFreq: MLXArray
499-
let _originalInvFreq: MLXArray
498+
private let _invFreq: MLXArray
499+
private let _originalInvFreq: MLXArray
500500

501501
init(config: TextConfig) {
502502
if let ropeScaling = config.ropeScaling {
@@ -750,7 +750,7 @@ private class Gemma3nAltUp: Module {
750750
@ModuleInfo(key: "prediction_coefs") var predictionCoefs: Linear
751751
@ModuleInfo(key: "modality_router") var modalityRouter: Linear
752752
@ModuleInfo(key: "router_norm") var routerNorm: Gemma3nRMSNormWithScale
753-
let _routerInputScale: MLXArray
753+
private let _routerInputScale: MLXArray
754754

755755
let config: TextConfig
756756

@@ -2340,7 +2340,7 @@ private class Gemma3nAudioAttention: Module {
23402340

23412341
@ModuleInfo(key: "relative_position_embedding") var relativePositionEmbedding:
23422342
Gemma3nAudioRelativePositionEmbedding
2343-
@ModuleInfo(key: "per_dim_scale") var perDimScale: MLXArray
2343+
private let _perDimScale: MLXArray
23442344
@ModuleInfo(key: "q_proj") var qProj: Linear
23452345
@ModuleInfo(key: "k_proj") var kProj: Linear
23462346
@ModuleInfo(key: "v_proj") var vProj: Linear
@@ -2359,7 +2359,7 @@ private class Gemma3nAudioAttention: Module {
23592359

23602360
self._relativePositionEmbedding.wrappedValue = Gemma3nAudioRelativePositionEmbedding(
23612361
config: config)
2362-
self._perDimScale.wrappedValue = MLXArray.zeros([headDim])
2362+
self._perDimScale = MLXArray.zeros([headDim])
23632363

23642364
self._qProj.wrappedValue = Linear(hiddenSize, numHeads * headDim, bias: false)
23652365
self._kProj.wrappedValue = Linear(hiddenSize, numHeads * headDim, bias: false)
@@ -2460,7 +2460,7 @@ private class Gemma3nAudioAttention: Module {
24602460
Array(x.shape.dropLast()) + [numHeads, headDim]
24612461
)
24622462

2463-
let perDimScaleSp = logAddExp(perDimScale, MLXArray(0.0))
2463+
let perDimScaleSp = logAddExp(_perDimScale, MLXArray(0.0))
24642464
let broadcastShape = [1, 1, 1, headDim]
24652465
let perDimScaleSpBroadcast = perDimScaleSp.reshaped(broadcastShape)
24662466
let scaledQueryStates = queryStates * qScale * perDimScaleSpBroadcast
@@ -2728,19 +2728,19 @@ private class Gemma3nAudioConformerBlock: Module {
27282728
// MARK: - Layer Scale 2D
27292729
private class LayerScale2d: Module, UnaryLayer {
27302730
let inplace: Bool
2731-
@ModuleInfo var gamma: MLXArray
2731+
private let _gamma: MLXArray
27322732

27332733
init(dim: Int, initValues: Float = 1e-5, inplace: Bool = false) {
27342734
self.inplace = inplace
2735-
self._gamma.wrappedValue = MLXArray(initValues) * MLXArray.ones([dim])
2735+
self._gamma = MLXArray(initValues) * MLXArray.ones([dim])
27362736
super.init()
27372737
}
27382738

27392739
func callAsFunction(_ x: MLXArray) -> MLXArray {
27402740
if inplace {
2741-
return x * gamma
2741+
return x * _gamma
27422742
} else {
2743-
return x * gamma
2743+
return x * _gamma
27442744
}
27452745
}
27462746
}

0 commit comments

Comments
 (0)