Skip to content

Commit 2d4fcad

Browse files
committed
Custom attention with cache update
1 parent 056599a commit 2d4fcad

File tree

1 file changed

+79
-22
lines changed

1 file changed

+79
-22
lines changed

Libraries/MLXVLM/Models/Gemma3n.swift

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -538,33 +538,24 @@ private class Gemma3nAttention: Module {
538538
values = vProj(x).reshaped(hiddenShape)
539539
values = vNorm(values)
540540
values = values.transposed(0, 2, 1, 3)
541-
542-
if let cache = cache {
543-
(keys, values) = cache.update(keys: keys, values: values)
544-
}
545541
}
546542

543+
// Repeat keys and values for multi-head attention
547544
keys = repeated(keys, count: repeats, axis: 1)
548545
values = repeated(values, count: repeats, axis: 1)
549546

550-
var attnWeights = matmul(queries, keys.swappedAxes(2, 3)) * scale
551-
552-
if attnLogitSoftcapping > 0 {
553-
attnWeights = attnWeights / attnLogitSoftcapping
554-
attnWeights = tanh(attnWeights)
555-
attnWeights = attnWeights * attnLogitSoftcapping
556-
}
557-
558-
if case .array(let maskArray) = mask {
559-
let causalMask = maskArray[0..., ..<keys.shape[2]]
560-
attnWeights = attnWeights + causalMask
561-
}
562-
563-
attnWeights = softmax(attnWeights.asType(.float32), axis: -1).asType(queries.dtype)
564-
565-
let output = matmul(attnWeights, values)
566-
.transposed(0, 2, 1, 3)
567-
.reshaped(inputShape + [-1])
547+
// Use custom attention function that supports both quantized cache and logit softcapping
548+
let output = gemma3nAttentionWithCacheUpdate(
549+
queries: queries,
550+
keys: keys,
551+
values: values,
552+
cache: cache,
553+
scale: scale,
554+
attnLogitSoftcapping: attnLogitSoftcapping,
555+
mask: mask ?? .none
556+
)
557+
.transposed(0, 2, 1, 3)
558+
.reshaped(inputShape + [-1])
568559

569560
return oProj(output)
570561
}
@@ -1308,6 +1299,72 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
13081299

13091300
// MARK: - Helper Functions
13101301

1302+
// MARK: - Custom Attention for Gemma3n with Logit Softcapping
1303+
1304+
/// Custom attention function for Gemma3n that supports:
1305+
/// - Logit softcapping (applied before softmax)
1306+
/// - Standard KV cache support
1307+
/// - Exact alignment with Python implementation
1308+
///
1309+
/// TODO: Quantized KV Cache Integration
1310+
/// Action items for adding quantized cache support:
1311+
/// 1. Add QuantizedKVCache detection: `if let quantizedKVCache = cache as? QuantizedKVCache`
1312+
/// 2. Use quantizedKVCache.updateQuantized(keys: keys, values: values) for cache update
1313+
/// 3. Implement manual quantized attention computation with logit softcapping:
1314+
/// - Cannot use quantizedScaledDotProductAttention directly (no softcapping support)
1315+
/// - Need to manually compute: matmul(queries, dequantized_keys) with softcapping
1316+
/// - May require dequantization of keys for logit softcapping application
1317+
/// 4. Consider performance trade-offs:
1318+
/// - Manual dequantization vs quantized attention benefits
1319+
/// - Might need hybrid approach or dedicated quantized+softcapping function
1320+
/// 5. Test with QuantizedKVCache to ensure numerical accuracy matches Python
1321+
/// 6. Update documentation and examples
1322+
private func gemma3nAttentionWithCacheUpdate(
1323+
queries: MLXArray,
1324+
keys: MLXArray,
1325+
values: MLXArray,
1326+
cache: KVCache?,
1327+
scale: Float,
1328+
attnLogitSoftcapping: Float,
1329+
mask: MLXFast.ScaledDotProductAttentionMaskMode = .none
1330+
) -> MLXArray {
1331+
// Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
1332+
let (cachedKeys, cachedValues): (MLXArray, MLXArray)
1333+
1334+
if let cache = cache {
1335+
(cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
1336+
} else {
1337+
(cachedKeys, cachedValues) = (keys, values)
1338+
}
1339+
1340+
// Manual attention computation to support logit softcapping
1341+
// This matches the Python implementation exactly:
1342+
// attn_weights = mx.matmul(queries, keys.swapaxes(2, 3)) * self.scale
1343+
var attnWeights = matmul(queries, cachedKeys.swappedAxes(2, 3)) * scale
1344+
1345+
// Apply logit softcapping if enabled (matches Python)
1346+
// if self.attn_logit_softcapping is not None and self.attn_logit_softcapping > 0:
1347+
if attnLogitSoftcapping > 0 {
1348+
attnWeights = attnWeights / attnLogitSoftcapping
1349+
attnWeights = tanh(attnWeights)
1350+
attnWeights = attnWeights * attnLogitSoftcapping
1351+
}
1352+
1353+
// Apply mask if provided (matches Python)
1354+
// if mask is not None: causal_mask = mask[:, : keys.shape[-2]]
1355+
if case .array(let maskArray) = mask {
1356+
let causalMask = maskArray[0..., ..<cachedKeys.shape[2]]
1357+
attnWeights = attnWeights + causalMask
1358+
}
1359+
1360+
// Apply softmax and compute output (matches Python)
1361+
// attn_weights = mx.softmax(attn_weights.astype(mx.float32), axis=-1).astype(queries.dtype)
1362+
attnWeights = softmax(attnWeights.asType(.float32), axis: -1).asType(queries.dtype)
1363+
1364+
// output = mx.matmul(attn_weights, values)
1365+
return matmul(attnWeights, cachedValues)
1366+
}
1367+
13111368
private func bicubicInterpolate(_ x: MLXArray, to targetSize: (Int, Int), alignCorners: Bool = false) -> MLXArray {
13121369
// TODO: This implementation uses nested loops and sequential MLX operations, which is much slower
13131370
// than the Python version that uses mx.fast.metal_kernel() for parallel GPU computation.

0 commit comments

Comments
 (0)