Skip to content

Commit 410c89c

Browse files
committed
Clean up
1 parent 0d6d026 commit 410c89c

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,18 +1146,18 @@ private class Gemma3Model: Module {
11461146
perLayerInputs: MLXArray? = nil
11471147
) -> MLXArray {
11481148
var h: MLXArray
1149-
if let inputsEmbeds = inputsEmbeds {
1149+
if let inputsEmbeds {
11501150
h = inputsEmbeds
1151-
} else if let inputs = inputs {
1151+
} else if let inputs {
11521152
h = embedTokens(inputs)
11531153
} else {
11541154
fatalError("Either inputs or inputsEmbeds must be provided")
11551155
}
11561156

11571157
let perLayerInputsProcessed: MLXArray
1158-
if let perLayerInputs = perLayerInputs {
1158+
if let perLayerInputs {
11591159
perLayerInputsProcessed = perLayerInputs
1160-
} else if let inputs = inputs {
1160+
} else if let inputs {
11611161
perLayerInputsProcessed = getPerLayerInputs(inputs)
11621162
} else {
11631163
fatalError("Cannot generate per layer inputs without input ids")
@@ -1213,7 +1213,7 @@ private class Gemma3Model: Module {
12131213
== "global_attention"
12141214

12151215
let localMask: MLXFast.ScaledDotProductAttentionMaskMode
1216-
if let mask = mask {
1216+
if let mask {
12171217
localMask = mask
12181218
} else if isGlobal {
12191219
localMask = fullMask
@@ -1437,9 +1437,9 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
14371437
}
14381438

14391439
let embNorm: MLXArray
1440-
if let inputsEmbeds = inputsEmbeds {
1440+
if let inputsEmbeds {
14411441
embNorm = softEmbeddingNorm(inputsEmbeds)
1442-
} else if let inputIds = inputIds {
1442+
} else if let inputIds {
14431443
let hardEmb = embedding(inputIds - vocabOffset)
14441444
embNorm = hardEmbeddingNorm(hardEmb)
14451445
} else {
@@ -1490,7 +1490,7 @@ private func gemma3nAttentionWithCacheUpdate(
14901490
// Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
14911491
let (cachedKeys, cachedValues): (MLXArray, MLXArray)
14921492

1493-
if let cache = cache {
1493+
if let cache {
14941494
(cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
14951495
} else {
14961496
(cachedKeys, cachedValues) = (keys, values)
@@ -1792,7 +1792,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
17921792
}
17931793

17941794
// Process audio features
1795-
if let inputFeatures = inputFeatures, let inputFeaturesMask = inputFeaturesMask {
1795+
if let inputFeatures, let inputFeaturesMask = inputFeaturesMask {
17961796
let (audioFeatures, audioMask) = getAudioFeatures(inputFeatures, .!inputFeaturesMask)
17971797
let audioPaddingIds = MLXArray([config.vocabSize - 1]).expandedDimensions(axis: 0)
17981798
let audioPaddingEmbs = embedAudio.callAsFunction(audioPaddingIds, inputsEmbeds: nil)
@@ -1862,7 +1862,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18621862
) -> MLXArray {
18631863
let specialModalityMask: MLXArray
18641864

1865-
if let inputIds = inputIds {
1865+
if let inputIds {
18661866
specialModalityMask = expandedDimensions(inputIds .== tokenId, axis: -1)
18671867
} else {
18681868
// When inputIds is nil, create mask by comparing embeddings
@@ -2211,7 +2211,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22112211
let expectedInputSuffix = featureDims + [numChannels]
22122212
assert(Array(x.shape.suffix(expectedInputSuffix.count)) == expectedInputSuffix)
22132213

2214-
if let mask = mask {
2214+
if let mask {
22152215
assert(mask.shape == Array(x.shape.prefix(2)))
22162216
assert(mask.dtype == .bool)
22172217
}
@@ -2221,7 +2221,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22212221
let xCalc = x.asType(calcDtype)
22222222

22232223
let maskCalc: MLXArray
2224-
if let mask = mask {
2224+
if let mask {
22252225
let maskSuffixShape = Array(repeating: 1, count: expectedInputSuffix.count)
22262226
maskCalc = mask.reshaped(Array(mask.shape) + maskSuffixShape).asType(calcDtype)
22272227
} else {
@@ -2848,7 +2848,7 @@ private func rmsNorm2d(
28482848
let vMean = mean(v, axis: 1, keepDims: true)
28492849
var result = x * rsqrt(vMean + eps)
28502850

2851-
if let weight = weight {
2851+
if let weight {
28522852
let weightReshaped = weight.reshaped([1, -1, 1, 1])
28532853
result = result.asType(dtype) * weightReshaped
28542854
}
@@ -3061,7 +3061,7 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
30613061
)
30623062

30633063
// Layer Scale
3064-
if let layerScaleInitValue = layerScaleInitValue {
3064+
if let layerScaleInitValue {
30653065
self._layerScale.wrappedValue = LayerScale2d(
30663066
dim: outChannels, initValues: layerScaleInitValue)
30673067
} else {
@@ -3420,7 +3420,7 @@ private class MobileAttention: Module, UnaryLayer {
34203420
}
34213421

34223422
// Layer scaling
3423-
if let layerScaleInitValue = layerScaleInitValue {
3423+
if let layerScaleInitValue {
34243424
self._layerScale.wrappedValue = LayerScale2d(
34253425
dim: outChannels, initValues: layerScaleInitValue)
34263426
} else {

0 commit comments

Comments
 (0)