@@ -538,33 +538,24 @@ private class Gemma3nAttention: Module {
538
538
values = vProj ( x) . reshaped ( hiddenShape)
539
539
values = vNorm ( values)
540
540
values = values. transposed ( 0 , 2 , 1 , 3 )
541
-
542
- if let cache = cache {
543
- ( keys, values) = cache. update ( keys: keys, values: values)
544
- }
545
541
}
546
542
543
+ // Repeat keys and values for multi-head attention
547
544
keys = repeated ( keys, count: repeats, axis: 1 )
548
545
values = repeated ( values, count: repeats, axis: 1 )
549
546
550
- var attnWeights = matmul ( queries, keys. swappedAxes ( 2 , 3 ) ) * scale
551
-
552
- if attnLogitSoftcapping > 0 {
553
- attnWeights = attnWeights / attnLogitSoftcapping
554
- attnWeights = tanh ( attnWeights)
555
- attnWeights = attnWeights * attnLogitSoftcapping
556
- }
557
-
558
- if case . array( let maskArray) = mask {
559
- let causalMask = maskArray [ 0 ... , ..< keys. shape [ 2 ] ]
560
- attnWeights = attnWeights + causalMask
561
- }
562
-
563
- attnWeights = softmax ( attnWeights. asType ( . float32) , axis: - 1 ) . asType ( queries. dtype)
564
-
565
- let output = matmul ( attnWeights, values)
566
- . transposed ( 0 , 2 , 1 , 3 )
567
- . reshaped ( inputShape + [ - 1 ] )
547
+ // Use custom attention function that supports both quantized cache and logit softcapping
548
+ let output = gemma3nAttentionWithCacheUpdate (
549
+ queries: queries,
550
+ keys: keys,
551
+ values: values,
552
+ cache: cache,
553
+ scale: scale,
554
+ attnLogitSoftcapping: attnLogitSoftcapping,
555
+ mask: mask ?? . none
556
+ )
557
+ . transposed ( 0 , 2 , 1 , 3 )
558
+ . reshaped ( inputShape + [ - 1 ] )
568
559
569
560
return oProj ( output)
570
561
}
@@ -1308,6 +1299,72 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
1308
1299
1309
1300
// MARK: - Helper Functions
1310
1301
1302
+ // MARK: - Custom Attention for Gemma3n with Logit Softcapping
1303
+
1304
+ /// Custom attention function for Gemma3n that supports:
1305
+ /// - Logit softcapping (applied before softmax)
1306
+ /// - Standard KV cache support
1307
+ /// - Exact alignment with Python implementation
1308
+ ///
1309
+ /// TODO: Quantized KV Cache Integration
1310
+ /// Action items for adding quantized cache support:
1311
+ /// 1. Add QuantizedKVCache detection: `if let quantizedKVCache = cache as? QuantizedKVCache`
1312
+ /// 2. Use quantizedKVCache.updateQuantized(keys: keys, values: values) for cache update
1313
+ /// 3. Implement manual quantized attention computation with logit softcapping:
1314
+ /// - Cannot use quantizedScaledDotProductAttention directly (no softcapping support)
1315
+ /// - Need to manually compute: matmul(queries, dequantized_keys) with softcapping
1316
+ /// - May require dequantization of keys for logit softcapping application
1317
+ /// 4. Consider performance trade-offs:
1318
+ /// - Manual dequantization vs quantized attention benefits
1319
+ /// - Might need hybrid approach or dedicated quantized+softcapping function
1320
+ /// 5. Test with QuantizedKVCache to ensure numerical accuracy matches Python
1321
+ /// 6. Update documentation and examples
1322
+ private func gemma3nAttentionWithCacheUpdate(
1323
+ queries: MLXArray ,
1324
+ keys: MLXArray ,
1325
+ values: MLXArray ,
1326
+ cache: KVCache ? ,
1327
+ scale: Float ,
1328
+ attnLogitSoftcapping: Float ,
1329
+ mask: MLXFast . ScaledDotProductAttentionMaskMode = . none
1330
+ ) -> MLXArray {
1331
+ // Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
1332
+ let ( cachedKeys, cachedValues) : ( MLXArray , MLXArray )
1333
+
1334
+ if let cache = cache {
1335
+ ( cachedKeys, cachedValues) = cache. update ( keys: keys, values: values)
1336
+ } else {
1337
+ ( cachedKeys, cachedValues) = ( keys, values)
1338
+ }
1339
+
1340
+ // Manual attention computation to support logit softcapping
1341
+ // This matches the Python implementation exactly:
1342
+ // attn_weights = mx.matmul(queries, keys.swapaxes(2, 3)) * self.scale
1343
+ var attnWeights = matmul ( queries, cachedKeys. swappedAxes ( 2 , 3 ) ) * scale
1344
+
1345
+ // Apply logit softcapping if enabled (matches Python)
1346
+ // if self.attn_logit_softcapping is not None and self.attn_logit_softcapping > 0:
1347
+ if attnLogitSoftcapping > 0 {
1348
+ attnWeights = attnWeights / attnLogitSoftcapping
1349
+ attnWeights = tanh ( attnWeights)
1350
+ attnWeights = attnWeights * attnLogitSoftcapping
1351
+ }
1352
+
1353
+ // Apply mask if provided (matches Python)
1354
+ // if mask is not None: causal_mask = mask[:, : keys.shape[-2]]
1355
+ if case . array( let maskArray) = mask {
1356
+ let causalMask = maskArray [ 0 ... , ..< cachedKeys. shape [ 2 ] ]
1357
+ attnWeights = attnWeights + causalMask
1358
+ }
1359
+
1360
+ // Apply softmax and compute output (matches Python)
1361
+ // attn_weights = mx.softmax(attn_weights.astype(mx.float32), axis=-1).astype(queries.dtype)
1362
+ attnWeights = softmax ( attnWeights. asType ( . float32) , axis: - 1 ) . asType ( queries. dtype)
1363
+
1364
+ // output = mx.matmul(attn_weights, values)
1365
+ return matmul ( attnWeights, cachedValues)
1366
+ }
1367
+
1311
1368
private func bicubicInterpolate( _ x: MLXArray , to targetSize: ( Int , Int ) , alignCorners: Bool = false ) -> MLXArray {
1312
1369
// TODO: This implementation uses nested loops and sequential MLX operations, which is much slower
1313
1370
// than the Python version that uses mx.fast.metal_kernel() for parallel GPU computation.
0 commit comments