@@ -1146,18 +1146,18 @@ private class Gemma3Model: Module {
1146
1146
perLayerInputs: MLXArray ? = nil
1147
1147
) -> MLXArray {
1148
1148
var h : MLXArray
1149
- if let inputsEmbeds = inputsEmbeds {
1149
+ if let inputsEmbeds {
1150
1150
h = inputsEmbeds
1151
- } else if let inputs = inputs {
1151
+ } else if let inputs {
1152
1152
h = embedTokens ( inputs)
1153
1153
} else {
1154
1154
fatalError ( " Either inputs or inputsEmbeds must be provided " )
1155
1155
}
1156
1156
1157
1157
let perLayerInputsProcessed : MLXArray
1158
- if let perLayerInputs = perLayerInputs {
1158
+ if let perLayerInputs {
1159
1159
perLayerInputsProcessed = perLayerInputs
1160
- } else if let inputs = inputs {
1160
+ } else if let inputs {
1161
1161
perLayerInputsProcessed = getPerLayerInputs ( inputs)
1162
1162
} else {
1163
1163
fatalError ( " Cannot generate per layer inputs without input ids " )
@@ -1213,7 +1213,7 @@ private class Gemma3Model: Module {
1213
1213
== " global_attention "
1214
1214
1215
1215
let localMask : MLXFast . ScaledDotProductAttentionMaskMode
1216
- if let mask = mask {
1216
+ if let mask {
1217
1217
localMask = mask
1218
1218
} else if isGlobal {
1219
1219
localMask = fullMask
@@ -1437,9 +1437,9 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1437
1437
}
1438
1438
1439
1439
let embNorm : MLXArray
1440
- if let inputsEmbeds = inputsEmbeds {
1440
+ if let inputsEmbeds {
1441
1441
embNorm = softEmbeddingNorm ( inputsEmbeds)
1442
- } else if let inputIds = inputIds {
1442
+ } else if let inputIds {
1443
1443
let hardEmb = embedding ( inputIds - vocabOffset)
1444
1444
embNorm = hardEmbeddingNorm ( hardEmb)
1445
1445
} else {
@@ -1490,7 +1490,7 @@ private func gemma3nAttentionWithCacheUpdate(
1490
1490
// Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
1491
1491
let ( cachedKeys, cachedValues) : ( MLXArray , MLXArray )
1492
1492
1493
- if let cache = cache {
1493
+ if let cache {
1494
1494
( cachedKeys, cachedValues) = cache. update ( keys: keys, values: values)
1495
1495
} else {
1496
1496
( cachedKeys, cachedValues) = ( keys, values)
@@ -1792,7 +1792,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1792
1792
}
1793
1793
1794
1794
// Process audio features
1795
- if let inputFeatures = inputFeatures , let inputFeaturesMask = inputFeaturesMask {
1795
+ if let inputFeatures, let inputFeaturesMask = inputFeaturesMask {
1796
1796
let ( audioFeatures, audioMask) = getAudioFeatures ( inputFeatures, .! inputFeaturesMask)
1797
1797
let audioPaddingIds = MLXArray ( [ config. vocabSize - 1 ] ) . expandedDimensions ( axis: 0 )
1798
1798
let audioPaddingEmbs = embedAudio. callAsFunction ( audioPaddingIds, inputsEmbeds: nil )
@@ -1862,7 +1862,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
1862
1862
) -> MLXArray {
1863
1863
let specialModalityMask : MLXArray
1864
1864
1865
- if let inputIds = inputIds {
1865
+ if let inputIds {
1866
1866
specialModalityMask = expandedDimensions ( inputIds .== tokenId, axis: - 1 )
1867
1867
} else {
1868
1868
// When inputIds is nil, create mask by comparing embeddings
@@ -2211,7 +2211,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
2211
2211
let expectedInputSuffix = featureDims + [ numChannels]
2212
2212
assert ( Array ( x. shape. suffix ( expectedInputSuffix. count) ) == expectedInputSuffix)
2213
2213
2214
- if let mask = mask {
2214
+ if let mask {
2215
2215
assert ( mask. shape == Array ( x. shape. prefix ( 2 ) ) )
2216
2216
assert ( mask. dtype == . bool)
2217
2217
}
@@ -2221,7 +2221,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
2221
2221
let xCalc = x. asType ( calcDtype)
2222
2222
2223
2223
let maskCalc : MLXArray
2224
- if let mask = mask {
2224
+ if let mask {
2225
2225
let maskSuffixShape = Array ( repeating: 1 , count: expectedInputSuffix. count)
2226
2226
maskCalc = mask. reshaped ( Array ( mask. shape) + maskSuffixShape) . asType ( calcDtype)
2227
2227
} else {
@@ -2848,7 +2848,7 @@ private func rmsNorm2d(
2848
2848
let vMean = mean ( v, axis: 1 , keepDims: true )
2849
2849
var result = x * rsqrt( vMean + eps)
2850
2850
2851
- if let weight = weight {
2851
+ if let weight {
2852
2852
let weightReshaped = weight. reshaped ( [ 1 , - 1 , 1 , 1 ] )
2853
2853
result = result. asType ( dtype) * weightReshaped
2854
2854
}
@@ -3061,7 +3061,7 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
3061
3061
)
3062
3062
3063
3063
// Layer Scale
3064
- if let layerScaleInitValue = layerScaleInitValue {
3064
+ if let layerScaleInitValue {
3065
3065
self . _layerScale. wrappedValue = LayerScale2d (
3066
3066
dim: outChannels, initValues: layerScaleInitValue)
3067
3067
} else {
@@ -3420,7 +3420,7 @@ private class MobileAttention: Module, UnaryLayer {
3420
3420
}
3421
3421
3422
3422
// Layer scaling
3423
- if let layerScaleInitValue = layerScaleInitValue {
3423
+ if let layerScaleInitValue {
3424
3424
self . _layerScale. wrappedValue = LayerScale2d (
3425
3425
dim: outChannels, initValues: layerScaleInitValue)
3426
3426
} else {
0 commit comments