From 4eefea962f5310c9a85abb98d64fc6f5264617d4 Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Thu, 6 Feb 2025 13:40:33 -0800 Subject: [PATCH 1/5] support for Qwen2.5-VL --- Libraries/MLXVLM/Models/Qwen25VL.swift | 1213 ++++++++++++++++++++++++ Libraries/MLXVLM/VLMModelFactory.swift | 9 + 2 files changed, 1222 insertions(+) create mode 100644 Libraries/MLXVLM/Models/Qwen25VL.swift diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift new file mode 100644 index 00000000..c8a68035 --- /dev/null +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -0,0 +1,1213 @@ +// +// Qwen25VL.swift +// mlx-swift-examples +// +// Created by Sachin Desai on 2/1/25. +// + +// port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen2_5_vl + +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Common + +/// Rotates half the hidden dims of the input +private func rotateHalf(_ x: MLXArray) -> MLXArray { + let index = x.dim(-1) / 2 + let x1 = x[.ellipsis, 0 ..< index] + let x2 = x[.ellipsis, index...] + return concatenated([-x2, x1], axis: -1) +} + +// MARK: - Language + +private enum Language { + static private func applyMultimodalRotaryPositionEmbedding( + q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, + positionIds: MLXArray, mropeSection: [Int] + ) -> (MLXArray, MLXArray) { + var cos = cos[positionIds] + var sin = sin[positionIds] + + cos = + concatenated( + // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] + split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + sin = + concatenated( + split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + // Apply rotary embedding + let qEmbed = (q * cos) + (rotateHalf(q) * sin) + let kEmbed = (k * cos) + (rotateHalf(k) * sin) + return (qEmbed, kEmbed) + } + + fileprivate class Attention: Module { + let heads: Int + let kvHeads: Int + let headDim: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + let dim = args.hiddenSize + self.heads = args.attentionHeads + self.kvHeads = args.kvHeads + self.headDim = dim / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + self._rotaryEmbedding.wrappedValue = RoPE( + dimensions: headDim, + traditional: args.ropeTraditional, + base: args.ropeTheta + ) + } + + public func callAsFunction( + _ x: MLXArray, + mask: MLXArray? = nil, + cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + + let offset = cache?.offset ?? 
0 + let mask = mask?[0..., 0 ..< keys.dim(-2)] + + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + + if let cache { + (keys, values) = cache.update(keys: keys, values: values) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class MLP: Module { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLDecoderLayer: Module { + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } + } + + fileprivate class Qwen25Model: Module { + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [Qwen25VLDecoderLayer] + fileprivate let norm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + Qwen25VLDecoderLayer(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> MLXArray { + var h: MLXArray + if let inputEmbedding { + h = inputEmbedding + } else if let inputs { + h = embedTokens(inputs) + } else { + fatalError("one of inputs or inputEmbedding must be non-nil") + } + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } + } + + fileprivate class LanguageModel: Module, KVCacheDimensionProvider { + @ModuleInfo var model: Qwen25Model + @ModuleInfo(key: "lm_head") var lmHead: Linear? 
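+        // `lm_head` is only materialized when `tie_word_embeddings` is false;
+        // with tied weights the decoder reuses the `embed_tokens` matrix as
+        // the output projection (see `callAsFunction` below).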
+ + var kvHeads: [Int] + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self.model = Qwen25Model(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> LMOutput { + var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding) + if let lmHead { + out = lmHead(out) + } else { + out = model.embedTokens.asLinear(out) + } + return LMOutput(logits: out) + } + } +} + +// MARK: - Vision + +private enum Vision { + static fileprivate func applyMultimodalRotaryPositionEmbedding( + _ tensor: MLXArray, freqs: MLXArray + ) -> MLXArray { + var cos = cos(freqs) + var sin = sin(freqs) + + cos = expandedDimensions(cos, axis: 1) + cos = tiled(cos, repetitions: [1, 1, 2]) + cos = expandedDimensions(cos, axis: 0) + + sin = expandedDimensions(sin, axis: 1) + sin = tiled(sin, repetitions: [1, 1, 2]) + sin = expandedDimensions(sin, axis: 0) + + let output = (tensor * cos) + (rotateHalf(tensor) * sin) + return output.asType(tensor.dtype) + } + + fileprivate class VisionRotaryEmbedding: Module { + let dimensions: Int + let theta: Float + let inverseFreq: MLXArray + + init(dimensions: Int, theta: Float = 10000.0) { + self.dimensions = dimensions + self.theta = theta + let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions + self.inverseFreq = 1.0 / pow(theta, p) + } + + func callAsFunction(_ sequenceLength: Int) -> MLXArray { + let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) + return outer(seq, inverseFreq) + } + } + + fileprivate class PatchEmbed: Module { + @ModuleInfo var proj: Conv3d + let patchSize: Int + let temporalPatchSize: Int + let inChannels: Int + let embedDimensions: Int + + init( + patchSize: Int = 14, + temporalPatchSize: Int = 2, + inChannels: Int = 3, + embedDimensions: Int = 1152 + ) { + self.patchSize = patchSize + self.temporalPatchSize = temporalPatchSize + self.inChannels = inChannels + self.embedDimensions = embedDimensions + + let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) + self._proj.wrappedValue = Conv3d( + inputChannels: inChannels, + outputChannels: embedDimensions, + kernelSize: kernelSize, + stride: kernelSize, + bias: false + ) + } + + func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { + var hiddenStates = hiddenStates.reshaped( + -1, inChannels, temporalPatchSize, patchSize, patchSize + ).movedAxis(source: 1, destination: 4) + + hiddenStates = proj(hiddenStates) + hiddenStates = hiddenStates.reshaped(-1, embedDimensions) + return hiddenStates + } + } + + fileprivate class PatchMerger: Module { + let hiddenSize: Int + @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm + @ModuleInfo var mlp: (Linear, GELU, Linear) + + init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int = 2) { + self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) + self._layerNormQ.wrappedValue = RMSNorm(dimensions: contextDimensions, eps: 1e-6) + self.mlp = ( + Linear(hiddenSize, hiddenSize), + GELU(), + Linear(hiddenSize, dimensions) + ) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = layerNormQ(x).reshaped(-1, hiddenSize) + x = mlp.0(x) + x = mlp.1(x) + x = mlp.2(x) + return x + } + } + + fileprivate class Attention: Module { + let numHeads: Int + let scale: Float + + @ModuleInfo(key: 
"qkv") var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + init(dim: Int, numHeads: Int = 16) { + self.numHeads = numHeads + let headDim = dim / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dim, dim * 3, bias: true) + self._proj.wrappedValue = Linear(dim, dim) + } + + func callAsFunction( + _ x: MLXArray, + cuSeqlens: MLXArray, + rotaryPositionEmbedding: MLXArray? = nil + ) -> MLXArray { + let seqLength = x.dim(0) + let qkv = qkv(x).reshaped(seqLength, 3, numHeads, -1).transposed(1, 0, 2, 3) + let (q, k, v) = ( + qkv[0].expandedDimensions(axis: 0), qkv[1].expandedDimensions(axis: 0), + qkv[2].expandedDimensions(axis: 0) + ) + + var queries = q + var keys = k + + if let rotaryPositionEmbedding { + queries = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + keys = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries.transposed(0, 2, 1, 3), + keys: keys.transposed(0, 2, 1, 3), + values: v.transposed(0, 2, 1, 3), + scale: scale, + mask: nil + ) + .transposed(0, 2, 1, 3) + .reshaped(seqLength, -1) + + return proj(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLVisionBlock: Module { + @ModuleInfo var norm1: RMSNorm + @ModuleInfo var norm2: RMSNorm + @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo var mlp: MLP + + init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.norm1 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + self.norm2 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + + self._attention.wrappedValue = Attention( + dim: config.hiddenSize, + numHeads: config.numHeads + ) + + self.mlp = MLP( + dimensions: config.hiddenSize, + hiddenDimensions: config.intermediateSize + ) + } + + func callAsFunction( + _ hiddenStates: MLXArray, + cuSeqlens: MLXArray, + rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + var hiddenStates = + hiddenStates + + attention( + norm1(hiddenStates), + cuSeqlens: cuSeqlens, + rotaryPositionEmbedding: rotaryPositionEmbedding + ) + + hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) + return hiddenStates + } + } + + fileprivate class VisionModel: Module { + @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed + @ModuleInfo(key: "merger") var merger: PatchMerger + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock] + + let windowSize: Int + let patchSize: Int + let spatialMergeSize: Int + let spatialMergeUnit: Int + let fullAttBlockIndexes: [Int] + + init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.windowSize = config.windowSize + self.patchSize = config.patchSize + self.spatialMergeSize = config.spatialMergeSize + self.fullAttBlockIndexes = config.fullAttBlockIndexes + + self.spatialMergeUnit = spatialMergeSize * spatialMergeSize + + self._patchEmbed.wrappedValue = PatchEmbed( + patchSize: 
config.patchSize, + temporalPatchSize: config.temporalPatchSize, + inChannels: config.inChannels, + embedDimensions: config.hiddenSize + ) + + let headDim = config.hiddenSize / config.numHeads + self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding( + dimensions: headDim / 2 + ) + + self._blocks.wrappedValue = (0 ..< config.depth).map { _ in + Qwen25VLVisionBlock(config) + } + + self._merger.wrappedValue = PatchMerger( + dimensions: config.outHiddenSize, contextDimensions: config.hiddenSize + ) + } + + func callAsFunction( + _ hiddenStates: MLXArray, + gridThw: [THW], + outputHiddenStates: Bool = false + ) -> MLXArray { + var hiddenStates = patchEmbed(hiddenStates) + var rotaryPosEmb = getRotaryPosEmb(gridThw) + var (windowIndex, cuWindowSeqlens) = getWindowIndex(gridThw) + + let seqlensArray = cuWindowSeqlens.asArray(Int.self) + var seen = Set() + var idx: [Int32] = [] + + for (i, x) in seqlensArray.enumerated() { + if !seen.contains(x) { + seen.insert(x) + idx.append(Int32(i)) + } + } + + let idx1 = MLXArray(idx) + cuWindowSeqlens = cuWindowSeqlens[idx1] + + let seqLen = hiddenStates.dim(0) + hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) + hiddenStates = hiddenStates[windowIndex, 0..., 0...] + hiddenStates = hiddenStates.reshaped(seqLen, -1) + + rotaryPosEmb = rotaryPosEmb.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) + rotaryPosEmb = rotaryPosEmb[windowIndex, 0..., 0...] + rotaryPosEmb = rotaryPosEmb.reshaped(seqLen, -1) + + // Assuming grid_thw has shape (batch_size, 3) + let batchSize = gridThw.count + + var cuSeqlens: [MLXArray] = [] + for row in gridThw { + let (gridT, gridH, gridW) = row.values + let seqLen = gridH * gridW + let repeats = gridT + + // Create array with repeated values + let repeatedSeq = MLXArray.full([repeats], values: MLXArray(seqLen)) + + cuSeqlens.append(repeatedSeq) + } + + let cuSeqlensPadded = padded( + cumsum(concatenated(cuSeqlens)), + width: .init((1, 0)), + mode: .constant, + value: MLXArray(0) + ) + + // Window processing + for (layerNum, block) in blocks.enumerated() { + let cuSeqlensNow = + fullAttBlockIndexes.contains(layerNum) ? cuSeqlensPadded : cuWindowSeqlens + hiddenStates = block( + hiddenStates, + cuSeqlens: cuSeqlensNow, + rotaryPositionEmbedding: rotaryPosEmb + ) + } + + hiddenStates = merger(hiddenStates) + let reverseIndices = argSort(windowIndex, axis: 0) + hiddenStates = hiddenStates[reverseIndices, 0...] 
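+            // `reverseIndices` inverts the earlier `windowIndex` gather, restoring
+            // the merged patch features to the grid's original raster order before
+            // they are spliced into the text embeddings.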
+ + return hiddenStates + } + + private func getRotaryPosEmb(_ gridThw: [THW]) -> MLXArray { + var posIds = [MLXArray]() + + for row in gridThw { + let (t, h, w) = row.values + + // Create and process horizontal position IDs + var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1) + hposIds = repeated(hposIds, count: w, axis: 1) + hposIds = hposIds.reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + hposIds = hposIds.transposed(0, 2, 1, 3) + hposIds = hposIds.flattened() + + // Create and process vertical position IDs + var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0) + wposIds = repeated(wposIds, count: h, axis: 0) + wposIds = wposIds.reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + wposIds = wposIds.transposed(0, 2, 1, 3) + wposIds = wposIds.flattened() + + // Stack and tile position IDs + let stackedPosIds = stacked([hposIds, wposIds], axis: -1) + posIds.append(tiled(stackedPosIds, repetitions: [t, 1])) + } + + let indices = concatenated(posIds, axis: 0) + let maxGridSize = gridThw.lazy.map({ max($0.h, $0.w) }).max() ?? 0 + let rotaryPosEmbFull = rotaryPositionEmbedding(maxGridSize) + let rotaryPosEmb = rotaryPosEmbFull[indices] + + return rotaryPosEmb.reshaped(indices.dim(0), -1) + } + + private func getWindowIndex(_ gridThw: [THW]) -> (MLXArray, MLXArray) { + var windowIndex = [MLXArray]() + var cuWindowSeqlens = [0] + var windowIndexId = [0] + let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize + + for row in gridThw { + let (gridT, gridH, gridW) = row.values + let llmGridH = gridH / spatialMergeSize + let llmGridW = gridW / spatialMergeSize + + // Create initial index array + let index = MLXArray(0 ..< (gridT * llmGridH * llmGridW)).reshaped( + gridT, llmGridH, llmGridW) + + // Calculate padding and window dimensions + let padH = vitMergerWindowSize - llmGridH % vitMergerWindowSize + let padW = vitMergerWindowSize - llmGridW % vitMergerWindowSize + let numWindowsH = (llmGridH + padH) / vitMergerWindowSize + let numWindowsW = (llmGridW + padW) / vitMergerWindowSize + + var indexPadded = padded( + index, + widths: [0, .init((0, padH)), .init((0, padW))], + mode: .constant, + value: MLXArray(-100, dtype: index.dtype)) + + // Reshape and transpose for window creation + indexPadded = indexPadded.reshaped( + gridT, + numWindowsH, + vitMergerWindowSize, + numWindowsW, + vitMergerWindowSize + ) + + indexPadded = indexPadded.transposed(0, 1, 3, 2, 4) + indexPadded = indexPadded.reshaped( + gridT, + numWindowsH * numWindowsW, + vitMergerWindowSize, + vitMergerWindowSize + ) + + // Process sequence lengths and indices + let seqlens = sum(indexPadded .!= -100, axes: [2, 3]).reshaped(-1) + indexPadded = indexPadded.reshaped(-1) + + var indices = [Int]() + for (i, v) in indexPadded.asArray(Int.self).enumerated() { + if v != -100 { + indices.append(v) + } + } + + let indexNew = MLXArray(indices) + + // Update window index and cumulative sequence lengths + windowIndex.append(indexNew + windowIndexId) + let cuSeqlensTmp = + cumsum(seqlens, axis: 0) * spatialMergeUnit + (cuWindowSeqlens.last ?? 
0) + + cuWindowSeqlens.append(contentsOf: cuSeqlensTmp.asArray(Int.self)) + windowIndexId += [gridT * llmGridH * llmGridW] + } + + // Create final arrays + let finalWindowIndex = concatenated(windowIndex, axis: 0) + let finalCuWindowSeqlens = MLXArray(cuWindowSeqlens) + return (finalWindowIndex, finalCuWindowSeqlens) + } + + private func getCuSeqlens(_ gridThw: [THW]) -> MLXArray { + var cuSeqlens = [MLXArray]() + + // Calculate cumulative sequence lengths for each item in batch + for row in gridThw { + let seqLen = row.h * row.w + let repeatedLen = repeated(MLXArray(seqLen), count: row.t, axis: 0) + cuSeqlens.append(repeatedLen) + } + + // Concatenate and process all sequence lengths + var result = concatenated(cuSeqlens, axis: 0) + result = cumsum(result.asType(.int32), axis: 0) + + var r = padded(result, width: .init((1, 0))) + + // Add leading zero for offset calculation + return r + } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4 && array.ndim != 5 { + return false + } + + if array.dim(-1) == 3 { + return true + } + + let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) + return outChannels >= kH && outChannels >= kW && kH == kW + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitizedWeights = [String: MLXArray]() + + for (k, v) in weights { + if k.contains("position_id") { + // Remove unused position_ids + continue + } else if k.contains("patch_embed.proj.weight") { + // PyTorch conv2d weight tensors have shape: + // [B, out_channels, in_channels, kH, KW] + // MLX conv2d expects the weight be of shape: + // [B, out_channels, kH, KW, in_channels] + if isMLXWeight(v) { + sanitizedWeights[k] = v + } else { + sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1) + } + } else { + sanitizedWeights[k] = v + } + } + + return sanitizedWeights + } + } +} + +public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { + @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel + @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel + + public let config: Qwen25VLConfiguration + + public var vocabularySize: Int { config.baseConfiguration.vocabularySize } + public var kvHeads: [Int] { languageModel.kvHeads } + + public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers { + languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ config: Qwen25VLConfiguration) { + self.config = config + self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration) + self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) + } + + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?) + -> MLXArray + { + guard let pixelValues, let gridThw else { + return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]) + } + + // Get input embeddings from language model + let inputEmbeds = languageModel.model.embedTokens(inputIds) + + // Get hidden states from vision model + var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw) + + if hiddenStates.ndim == 2 { + hiddenStates = hiddenStates[.newAxis, 0..., 0...] 
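+            // The vision features now carry a leading batch axis,
+            // (numPatches, hiddenSize) -> (1, numPatches, hiddenSize), matching
+            // the batched input embeddings they are merged into below.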
+ } + + // Merge input IDs with image features + return mergeInputIdsWithImageFeatures( + inputIds: inputIds, + inputEmbeds: inputEmbeds, + imageFeatures: hiddenStates + ) + } + + private func mergeInputIdsWithImageFeatures( + inputIds: MLXArray, + inputEmbeds: MLXArray, + imageFeatures: MLXArray + ) -> MLXArray { + let imageTokenIndex = config.baseConfiguration.imageTokenId + let videoTokenIndex = config.baseConfiguration.videoTokenId + + var imageIndices = [Int]() + for (i, v) in inputIds.asArray(Int.self).enumerated() { + if v == imageTokenIndex { + imageIndices.append(i) + } + } + + if imageIndices.isEmpty { + for (i, v) in inputIds.asArray(Int.self).enumerated() { + if v == videoTokenIndex { + imageIndices.append(i) + } + } + } + + inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures + return inputEmbeds + } + + public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws + -> PrepareResult + { + let dtype = visionModel.patchEmbed.proj.weight.dtype + + let imageGridThw = input.image?.imageGridThw + let imagePixels = input.image?.pixels.asType(dtype) + + let videoGridThw = input.video?.videoGridThw + let videoPixels = input.video?.pixels.asType(dtype) + + let gridThw: [THW]? + let pixels: MLXArray? + + if videoGridThw == nil { + gridThw = imageGridThw + pixels = imagePixels + } else { + gridThw = videoGridThw + pixels = videoPixels + } + + let inputEmbeddings = self.inputEmbeddings( + inputIds: input.text.tokens, + pixelValues: pixels, + gridThw: gridThw + ) + + let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) + return .logits(result) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray { + languageModel(inputs, cache: cache).logits + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + visionModel.sanitize( + weights: Dictionary( + uniqueKeysWithValues: weights.map { key, value in + var key = key + if !key.contains("vision_tower") { + key = key.replacingOccurrences(of: "visual", with: "vision_tower") + } + if !key.contains("language_model") { + key = key.replacingOccurrences(of: "model", with: "language_model.model") + key = key.replacingOccurrences( + of: "lm_head", with: "language_model.lm_head") + } + return (key, value) + }) + ) + } +} + +// MARK: - Configuration + +public struct Qwen25VLConfiguration: Codable, Sendable { + public struct TextConfiguration: Codable, Sendable { + public let modelType: String + public let hiddenSize: Int + public let hiddenLayers: Int + public let intermediateSize: Int + public let attentionHeads: Int + private let _rmsNormEps: Float? + public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 } + public let vocabularySize: Int + public let kvHeads: Int + private let _maxPositionEmbeddings: Int? + public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 128000 } + private let _ropeTheta: Float? + public var ropeTheta: Float { _ropeTheta ?? 1_000_000 } + private let _ropeTraditional: Bool? + public var ropeTraditional: Bool { _ropeTraditional ?? false } + public let ropeScaling: [String: StringOrNumber]? + private let _tieWordEmbeddings: Bool? + public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? 
true } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case _rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case _maxPositionEmbeddings = "max_position_embeddings" + case _ropeTheta = "rope_theta" + case _ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case _tieWordEmbeddings = "tie_word_embeddings" + } + } + + public struct VisionConfiguration: Codable, Sendable { + public let depth: Int + public let hiddenSize: Int + public let intermediateSize: Int + public let outHiddenSize: Int + public let numHeads: Int + public let patchSize: Int + public let _inChannels: Int? + public var inChannels: Int { _inChannels ?? 3 } + public let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + public let spatialPatchSize: Int + public let spatialMergeSize: Int + public let temporalPatchSize: Int + public let windowSize: Int + public let fullAttBlockIndexes: [Int] + public let tokensPerSecond: Int + + enum CodingKeys: String, CodingKey { + case depth + case hiddenSize = "hidden_size" + case intermediateSize = "intermediate_size" + case outHiddenSize = "out_hidden_size" + case numHeads = "num_heads" + case patchSize = "patch_size" + case _inChannels = "in_channels" + case _layerNormEps = "layer_norm_eps" + case spatialPatchSize = "spatial_patch_size" + case spatialMergeSize = "spatial_merge_size" + case temporalPatchSize = "temporal_patch_size" + case windowSize = "window_size" + case fullAttBlockIndexes = "fullatt_block_indexes" + case tokensPerSecond = "tokens_per_second" + } + } + + public struct BaseConfiguration: Codable, Sendable { + public let modelType: String + public let vocabularySize: Int + public let imageTokenId: Int + public let videoTokenId: Int + public let hiddenSize: Int + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabularySize = "vocab_size" + case imageTokenId = "image_token_id" + case videoTokenId = "video_token_id" + case hiddenSize = "hidden_size" + } + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: VisionConfiguration + public let baseConfiguration: BaseConfiguration + + enum CodingKeys: String, CodingKey { + case visionConfiguration = "vision_config" + } + + public init(from decoder: any Swift.Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + // Vision config is in a sub-dictionary + self.visionConfiguration = try container.decode( + VisionConfiguration.self, + forKey: .visionConfiguration + ) + + // Text and base configs are overlaid in the top level + self.textConfiguration = try TextConfiguration(from: decoder) + self.baseConfiguration = try BaseConfiguration(from: decoder) + } +} + +public class Qwen25VLProcessor: UserInputProcessor { + private let config: Qwen25VLProcessorConfiguration + private let tokenizer: any Tokenizer + + public init(_ config: Qwen25VLProcessorConfiguration, tokenizer: any Tokenizer) { + self.config = config + self.tokenizer = tokenizer + } + + // image_processing_qwen2_vl.smart_resize + private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) + throws -> (Int, Int) + { + if height < factor { + throw VLMError.imageProcessingFailure( + "height: \(height) must be larger than factor: \(factor)") + } 
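+        // The same lower bound applies to the width; together with the aspect
+        // ratio check below, this mirrors image_processing_qwen2_vl.smart_resize.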
+ if width < factor { + throw VLMError.imageProcessingFailure( + "width: \(width) must be larger than factor: \(factor)") + } + if max(height, width) / min(height, width) > 200 { + throw VLMError.imageProcessingFailure( + "absolute aspect ratio must be smaller than 200: \(width)x\(height)") + } + + var hBar = Int(round(Float(height) / Float(factor))) * factor + var wBar = Int(round(Float(width) / Float(factor))) * factor + + if hBar * wBar > maxPixels { + let beta = sqrt(Float(height * width) / Float(maxPixels)) + hBar = Int(floor(Float(height) / beta / Float(factor))) * factor + wBar = Int(floor(Float(width) / beta / Float(factor))) * factor + } else if hBar * wBar < minPixels { + let beta = sqrt(Float(minPixels) / Float(height * width)) + hBar = Int(floor(Float(height) * beta / Float(factor))) * factor + wBar = Int(floor(Float(width) * beta / Float(factor))) * factor + } + return (hBar, wBar) + } + + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( + MLXArray, THW + ) { + // first apply the user requested resizing, etc. if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + + let size = images[0].extent.size + let (resizedHeight, resizedWidth) = try targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.minPixels, maxPixels: config.maxPixels) + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + return MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + var patches = concatenated(processedImages) + let mod = patches.dim(0) % config.temporalPatchSize + if mod != 0 { + let lastPatch = patches[-1, .ellipsis] + let lastPatchRepeated = tiled( + lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) + patches = concatenated([patches, lastPatchRepeated]) + } + let channel = patches.dim(1) + let gridT = patches.dim(0) / self.config.temporalPatchSize + let gridH = resizedHeight / self.config.patchSize + let gridW = resizedWidth / self.config.patchSize + + patches = patches.reshaped( + gridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) + + let flattenedPatches = patches.reshaped( + gridT * gridH * gridW, + channel * config.temporalPatchSize * config.patchSize * config.patchSize + ) + + return (flattenedPatches, .init(gridT, gridH, gridW)) + } + + public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String { + // the tokenizer does have a chat template and it expects messages + // like this: + // + // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'}, + // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}] + // + // The output of the prompt template is fed into + // image_processing_qwen2_vl.preprocess where it is further augmented + // by replacing tokens according to imageTHW. + // + // Neither the structured content nor the postprocessing of the template + // are supported in current Tokenizer/Jinja (swift) so handle that here. 
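+        //
+        // As a sketch (sizes are hypothetical): a single 224x224 image with
+        // patchSize = 14 and mergeSize = 2 resolves to a (t, h, w) grid of
+        // (1, 16, 16), i.e. 256 patches merged 4:1 into 64 <|image_pad|>
+        // tokens, so the rendered prompt looks like:
+        //
+        //   <|im_start|>system
+        //   You are a helpful assistant.<|im_end|>
+        //   <|im_start|>user
+        //   What are these?<|vision_start|><|image_pad|>(x64)<|vision_end|><|im_end|>
+        //   <|im_start|>assistant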
+ + var messages = prompt.asMessages() + if messages[0]["role"] != "system" { + messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0) + } + + let lastIndex = messages.count - 1 + var lastMessage = messages[lastIndex]["content"] ?? "" + + // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image + let mergeLength = config.mergeSize * config.mergeSize + for thw in imageTHW ?? [] { + lastMessage += "<|vision_start|>" + lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength) + .joined() + lastMessage += "<|vision_end|>" + } + + for thw in videoTHW ?? [] { + lastMessage += "<|vision_start|>" + lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength) + .joined() + lastMessage += "<|vision_end|>" + } + + messages[lastIndex]["content"] = lastMessage + + return + messages + .map { + "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>" + } + .joined(separator: "\n") + + "\n<|im_start|>assistant\n" + } + + public func prepare(input: UserInput) async throws -> LMInput { + if input.images.isEmpty && input.videos.isEmpty { + // just a straight text prompt + let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil) + let promptTokens = try tokenizer.encode(text: prompt) + return LMInput(tokens: MLXArray(promptTokens)) + } + + // image_processing_qwen2_vl.preprocess + let images = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + + var videosAsImageSequences = [[CIImage]]() + for video in input.videos { + if let imageSequence = try? await MediaProcessing.asCIImageSequence( + video.asAVAsset(), samplesPerSecond: 2) + { + videosAsImageSequences.append(imageSequence) + } + } + let videos = try videosAsImageSequences.map { + try preprocess(images: $0, processing: input.processing) + } + + let imagePixels: MLXArray? + let image: LMInput.ProcessedImage? + if !images.isEmpty { + imagePixels = concatenated(images.map { $0.0 }) + image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 }) + } else { + imagePixels = nil + image = nil + } + + let videoPixels: MLXArray? + let video: LMInput.ProcessedVideo? 
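+        // Videos mirror the image path above: concatenate per-video patches and
+        // keep one THW grid per video so `prepare(prompt:)` can inject the
+        // matching number of <|video_pad|> tokens.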
+ if !videos.isEmpty { + videoPixels = concatenated(videos.map { $0.0 }) + video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 }) + } else { + videoPixels = nil + video = nil + } + + let prompt = prepare( + prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw) + let promptTokens = try tokenizer.encode(text: prompt) + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + + return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video) + } +} + +public struct Qwen25VLProcessorConfiguration: Codable, Sendable { + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let maxPixels: Int + public let minPixels: Int + public let mergeSize: Int + public let patchSize: Int + public let temporalPatchSize: Int + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case maxPixels = "max_pixels" + case minPixels = "min_pixels" + case mergeSize = "merge_size" + case patchSize = "patch_size" + case temporalPatchSize = "temporal_patch_size" + } +} diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 73f654a9..343fdee6 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -60,6 +60,7 @@ public class ModelTypeRegistry: @unchecked Sendable { private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), + "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init), "idefics3": create(Idefics3Configuration.self, Idefics3.init), ] @@ -101,6 +102,8 @@ public class ProcessorTypeRegistry: @unchecked Sendable { PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init), "Qwen2VLProcessor": create( Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init), + "Qwen2_5_VLProcessor": create( + Qwen25VLProcessorConfiguration.self, Qwen25VLProcessor.init), "Idefics3Processor": create( Idefics3ProcessorConfiguration.self, Idefics3Processor.init), ] @@ -157,6 +160,11 @@ public class ModelRegistry: @unchecked Sendable { defaultPrompt: "Describe the image in English" ) + static public let qwen25VL3BInstruct4Bit = ModelConfiguration( + id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit", + defaultPrompt: "Describe the image in English" + ) + static public let smolvlminstruct4bit = ModelConfiguration( id: "mlx-community/SmolVLM-Instruct-4bit", defaultPrompt: "Describe the image in English" @@ -166,6 +174,7 @@ public class ModelRegistry: @unchecked Sendable { [ paligemma3bMix448_8bit, qwen2VL2BInstruct4Bit, + qwen25VL3BInstruct4Bit, ] } From 5871551f4b385121d425c53e08bc839d7347baf2 Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Fri, 14 Feb 2025 13:19:00 -0800 Subject: [PATCH 2/5] pulled in changes from #173 and added chat template for Qwen2.5VL as the one in config does not support image/video merged with changes from #173 and added chat template for Qwen2.5VL as the one in the config does not support image/video --- Libraries/MLXVLM/Models/Qwen25VL.swift | 469 +++++------------- Libraries/MLXVLM/Models/Qwen2VL.swift | 214 +------- 
Libraries/MLXVLM/Models/QwenVLProcessor.swift | 231 +++++++++ 3 files changed, 365 insertions(+), 549 deletions(-) create mode 100644 Libraries/MLXVLM/Models/QwenVLProcessor.swift diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index c8a68035..2e338709 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -60,11 +60,13 @@ private enum Language { let kvHeads: Int let headDim: Int let scale: Float + let mropeSection: [Int] @ModuleInfo(key: "q_proj") var wq: Linear @ModuleInfo(key: "k_proj") var wk: Linear @ModuleInfo(key: "v_proj") var wv: Linear @ModuleInfo(key: "o_proj") var wo: Linear + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE public init(_ args: Qwen25VLConfiguration.TextConfiguration) { @@ -78,6 +80,21 @@ private enum Language { self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { + self.mropeSection = sequence(state: (0, array.makeIterator())) { state in + if let v = state.1.next() { + // note the *2 + state.0 += v * 2 + return state.0 + } else { + return nil + } + }.dropLast() + } else { + fatalError("rope_scaling['mrope_section'] must be array of integers") + } + self._rotaryEmbedding.wrappedValue = RoPE( dimensions: headDim, traditional: args.ropeTraditional, @@ -262,16 +279,15 @@ private enum Vision { fileprivate class VisionRotaryEmbedding: Module { let dimensions: Int let theta: Float - let inverseFreq: MLXArray init(dimensions: Int, theta: Float = 10000.0) { self.dimensions = dimensions self.theta = theta - let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions - self.inverseFreq = 1.0 / pow(theta, p) } func callAsFunction(_ sequenceLength: Int) -> MLXArray { + let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions + let inverseFreq = 1.0 / pow(theta, p) let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) return outer(seq, inverseFreq) } @@ -279,26 +295,27 @@ private enum Vision { fileprivate class PatchEmbed: Module { @ModuleInfo var proj: Conv3d + let patchSize: Int let temporalPatchSize: Int let inChannels: Int - let embedDimensions: Int + let hiddenSize: Int init( patchSize: Int = 14, temporalPatchSize: Int = 2, inChannels: Int = 3, - embedDimensions: Int = 1152 + hiddenSize: Int = 1152 ) { self.patchSize = patchSize self.temporalPatchSize = temporalPatchSize self.inChannels = inChannels - self.embedDimensions = embedDimensions + self.hiddenSize = hiddenSize let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) self._proj.wrappedValue = Conv3d( inputChannels: inChannels, - outputChannels: embedDimensions, + outputChannels: hiddenSize, kernelSize: kernelSize, stride: kernelSize, bias: false @@ -311,13 +328,14 @@ private enum Vision { ).movedAxis(source: 1, destination: 4) hiddenStates = proj(hiddenStates) - hiddenStates = hiddenStates.reshaped(-1, embedDimensions) + hiddenStates = hiddenStates.reshaped(-1, hiddenSize) return hiddenStates } } fileprivate class PatchMerger: Module { let hiddenSize: Int + @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm @ModuleInfo var mlp: (Linear, GELU, Linear) @@ -358,28 +376,34 @@ private enum Vision { func callAsFunction( _ x: MLXArray, - cuSeqlens: MLXArray, + frames: [THW], rotaryPositionEmbedding: MLXArray? 
= nil ) -> MLXArray { let seqLength = x.dim(0) - let qkv = qkv(x).reshaped(seqLength, 3, numHeads, -1).transposed(1, 0, 2, 3) - let (q, k, v) = ( - qkv[0].expandedDimensions(axis: 0), qkv[1].expandedDimensions(axis: 0), - qkv[2].expandedDimensions(axis: 0) - ) + let B = frames[0].t + let L = seqLength / B + + let qkv = qkv(x) + let s = split(qkv, parts: 3, axis: -1) + var (q, k, v) = (s[0], s[1], s[2]) - var queries = q - var keys = k + q = q.reshaped(seqLength, numHeads, -1) + k = k.reshaped(seqLength, numHeads, -1) + v = v.reshaped(seqLength, numHeads, -1) if let rotaryPositionEmbedding { - queries = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) - keys = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) } + q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + let output = MLXFast.scaledDotProductAttention( - queries: queries.transposed(0, 2, 1, 3), - keys: keys.transposed(0, 2, 1, 3), - values: v.transposed(0, 2, 1, 3), + queries: q, + keys: k, + values: v, scale: scale, mask: nil ) @@ -392,13 +416,13 @@ private enum Vision { fileprivate class MLP: Module, UnaryLayer { @ModuleInfo(key: "gate_proj") var gate: Linear - @ModuleInfo(key: "down_proj") var down: Linear @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear public init(dimensions: Int, hiddenDimensions: Int) { self._gate.wrappedValue = Linear(dimensions, hiddenDimensions) - self._down.wrappedValue = Linear(hiddenDimensions, dimensions) self._up.wrappedValue = Linear(dimensions, hiddenDimensions) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions) } public func callAsFunction(_ x: MLXArray) -> MLXArray { @@ -429,14 +453,14 @@ private enum Vision { func callAsFunction( _ hiddenStates: MLXArray, - cuSeqlens: MLXArray, + frames: [THW], rotaryPositionEmbedding: MLXArray ) -> MLXArray { var hiddenStates = hiddenStates + attention( norm1(hiddenStates), - cuSeqlens: cuSeqlens, + frames: frames, rotaryPositionEmbedding: rotaryPositionEmbedding ) @@ -469,7 +493,7 @@ private enum Vision { patchSize: config.patchSize, temporalPatchSize: config.temporalPatchSize, inChannels: config.inChannels, - embedDimensions: config.hiddenSize + hiddenSize: config.hiddenSize ) let headDim = config.hiddenSize / config.numHeads @@ -488,65 +512,21 @@ private enum Vision { func callAsFunction( _ hiddenStates: MLXArray, - gridThw: [THW], + frames: [THW], outputHiddenStates: Bool = false ) -> MLXArray { var hiddenStates = patchEmbed(hiddenStates) - var rotaryPosEmb = getRotaryPosEmb(gridThw) - var (windowIndex, cuWindowSeqlens) = getWindowIndex(gridThw) - - let seqlensArray = cuWindowSeqlens.asArray(Int.self) - var seen = Set() - var idx: [Int32] = [] - - for (i, x) in seqlensArray.enumerated() { - if !seen.contains(x) { - seen.insert(x) - idx.append(Int32(i)) - } - } - - let idx1 = MLXArray(idx) - cuWindowSeqlens = cuWindowSeqlens[idx1] - - let seqLen = hiddenStates.dim(0) - hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) - hiddenStates = hiddenStates[windowIndex, 0..., 0...] 
- hiddenStates = hiddenStates.reshaped(seqLen, -1) - - rotaryPosEmb = rotaryPosEmb.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) - rotaryPosEmb = rotaryPosEmb[windowIndex, 0..., 0...] - rotaryPosEmb = rotaryPosEmb.reshaped(seqLen, -1) + var rotaryPosEmb = getRotaryPosEmb(frames) + var (windowIndex, cuWindowSeqlens) = getWindowIndex(frames) // Assuming grid_thw has shape (batch_size, 3) - let batchSize = gridThw.count - - var cuSeqlens: [MLXArray] = [] - for row in gridThw { - let (gridT, gridH, gridW) = row.values - let seqLen = gridH * gridW - let repeats = gridT - - // Create array with repeated values - let repeatedSeq = MLXArray.full([repeats], values: MLXArray(seqLen)) - - cuSeqlens.append(repeatedSeq) - } - - let cuSeqlensPadded = padded( - cumsum(concatenated(cuSeqlens)), - width: .init((1, 0)), - mode: .constant, - value: MLXArray(0) - ) + let batchSize = frames.count // Window processing for (layerNum, block) in blocks.enumerated() { - let cuSeqlensNow = - fullAttBlockIndexes.contains(layerNum) ? cuSeqlensPadded : cuWindowSeqlens hiddenStates = block( hiddenStates, - cuSeqlens: cuSeqlensNow, + frames: frames, rotaryPositionEmbedding: rotaryPosEmb ) } @@ -558,13 +538,12 @@ private enum Vision { return hiddenStates } - private func getRotaryPosEmb(_ gridThw: [THW]) -> MLXArray { + private func getRotaryPosEmb(_ frames: [THW]) -> MLXArray { var posIds = [MLXArray]() - for row in gridThw { + for row in frames { let (t, h, w) = row.values - // Create and process horizontal position IDs var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1) hposIds = repeated(hposIds, count: w, axis: 1) hposIds = hposIds.reshaped( @@ -573,10 +552,9 @@ private enum Vision { w / spatialMergeSize, spatialMergeSize ) - hposIds = hposIds.transposed(0, 2, 1, 3) - hposIds = hposIds.flattened() + .transposed(0, 2, 1, 3) + .flattened() - // Create and process vertical position IDs var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0) wposIds = repeated(wposIds, count: h, axis: 0) wposIds = wposIds.reshaped( @@ -585,29 +563,28 @@ private enum Vision { w / spatialMergeSize, spatialMergeSize ) - wposIds = wposIds.transposed(0, 2, 1, 3) - wposIds = wposIds.flattened() + .transposed(0, 2, 1, 3) + .flattened() - // Stack and tile position IDs let stackedPosIds = stacked([hposIds, wposIds], axis: -1) posIds.append(tiled(stackedPosIds, repetitions: [t, 1])) } let indices = concatenated(posIds, axis: 0) - let maxGridSize = gridThw.lazy.map({ max($0.h, $0.w) }).max() ?? 0 - let rotaryPosEmbFull = rotaryPositionEmbedding(maxGridSize) - let rotaryPosEmb = rotaryPosEmbFull[indices] + let maxFrameSize = frames.lazy.map({ max($0.h, $0.w) }).max() ?? 
0 + let rotaryPosEmb = rotaryPositionEmbedding(maxFrameSize) + let rotaryPosEmbFull = rotaryPosEmb[indices] - return rotaryPosEmb.reshaped(indices.dim(0), -1) + return rotaryPosEmbFull.reshaped(indices.dim(0), -1) } - private func getWindowIndex(_ gridThw: [THW]) -> (MLXArray, MLXArray) { + private func getWindowIndex(_ frames: [THW]) -> (MLXArray, MLXArray) { var windowIndex = [MLXArray]() var cuWindowSeqlens = [0] var windowIndexId = [0] let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize - for row in gridThw { + for row in frames { let (gridT, gridH, gridW) = row.values let llmGridH = gridH / spatialMergeSize let llmGridW = gridW / spatialMergeSize @@ -637,13 +614,15 @@ private enum Vision { vitMergerWindowSize ) - indexPadded = indexPadded.transposed(0, 1, 3, 2, 4) - indexPadded = indexPadded.reshaped( - gridT, - numWindowsH * numWindowsW, - vitMergerWindowSize, - vitMergerWindowSize - ) + indexPadded = + indexPadded + .transposed(0, 1, 3, 2, 4) + .reshaped( + gridT, + numWindowsH * numWindowsW, + vitMergerWindowSize, + vitMergerWindowSize + ) // Process sequence lengths and indices let seqlens = sum(indexPadded .!= -100, axes: [2, 3]).reshaped(-1) @@ -752,10 +731,10 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) } - private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?) + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?) -> MLXArray { - guard let pixelValues, let gridThw else { + guard let pixelValues, let frames else { return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]) } @@ -763,13 +742,13 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { let inputEmbeds = languageModel.model.embedTokens(inputIds) // Get hidden states from vision model - var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw) + var hiddenStates = self.visionModel(pixelValues, frames: frames) if hiddenStates.ndim == 2 { hiddenStates = hiddenStates[.newAxis, 0..., 0...] } - // Merge input IDs with image features + // Insert special image tokens in the input_ids return mergeInputIdsWithImageFeatures( inputIds: inputIds, inputEmbeds: inputEmbeds, @@ -787,21 +766,25 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { var imageIndices = [Int]() for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == imageTokenIndex { + if v == imageTokenIndex || v == videoTokenIndex { imageIndices.append(i) } } - if imageIndices.isEmpty { - for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == videoTokenIndex { - imageIndices.append(i) - } - } + // Make sure shapes match before assignment + var result = inputEmbeds + if result.ndim == 2 { + result = result[.newAxis, 0..., 0...] + } + + if imageFeatures.ndim == 2 { + let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] + result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures + } else { + result[0..., MLXArray(imageIndices), 0...] = imageFeatures } - inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures - return inputEmbeds + return result } public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) 
throws @@ -809,30 +792,29 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { { let dtype = visionModel.patchEmbed.proj.weight.dtype - let imageGridThw = input.image?.imageGridThw - let imagePixels = input.image?.pixels.asType(dtype) + var allPixels: MLXArray? + var allFrames: [THW] = [] - let videoGridThw = input.video?.videoGridThw - let videoPixels = input.video?.pixels.asType(dtype) - - let gridThw: [THW]? - let pixels: MLXArray? + if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames { + allPixels = imagePixels.asType(dtype) + allFrames.append(contentsOf: imageFrames) + } - if videoGridThw == nil { - gridThw = imageGridThw - pixels = imagePixels - } else { - gridThw = videoGridThw - pixels = videoPixels + if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames { + if allPixels == nil { + allPixels = videoPixels.asType(dtype) + } else { + allPixels = concatenated([allPixels!, videoPixels.asType(dtype)]) + } + allFrames.append(contentsOf: videoFrames) } let inputEmbeddings = self.inputEmbeddings( - inputIds: input.text.tokens, - pixelValues: pixels, - gridThw: gridThw - ) + inputIds: input.text.tokens, pixelValues: allPixels, + frames: allFrames.isEmpty ? nil : allFrames) let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) + return .logits(result) } @@ -974,218 +956,16 @@ public struct Qwen25VLConfiguration: Codable, Sendable { } } -public class Qwen25VLProcessor: UserInputProcessor { - private let config: Qwen25VLProcessorConfiguration - private let tokenizer: any Tokenizer - - public init(_ config: Qwen25VLProcessorConfiguration, tokenizer: any Tokenizer) { - self.config = config - self.tokenizer = tokenizer - } - - // image_processing_qwen2_vl.smart_resize - private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = Int(round(Float(height) / Float(factor))) * factor - var wBar = Int(round(Float(width) / Float(factor))) * factor - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(floor(Float(height) * beta / Float(factor))) * factor - wBar = Int(floor(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - - public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( - MLXArray, THW - ) { - // first apply the user requested resizing, etc. 
if any - let images = images.map { MediaProcessing.apply($0, processing: processing) } - - // image_processing_qwen2_vl._preprocess - - let size = images[0].extent.size - let (resizedHeight, resizedWidth) = try targetSize( - height: Int(size.height), width: Int(size.width), - factor: config.patchSize * config.mergeSize, - minPixels: config.minPixels, maxPixels: config.maxPixels) - let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) - - let processedImages = - try images - .map { - MediaProcessing.inSRGBToneCurveSpace($0) - } - .map { - return MediaProcessing.resampleBicubic($0, to: resizedSize) - } - .map { - MediaProcessing.normalize( - $0, mean: config.imageMeanTuple, std: config.imageStdTuple) - } - .map { - MediaProcessing.asMLXArray($0) - } - - var patches = concatenated(processedImages) - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) - - let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize - ) - - return (flattenedPatches, .init(gridT, gridH, gridW)) - } - - public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String { - // the tokenizer does have a chat template and it expects messages - // like this: - // - // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'}, - // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}] - // - // The output of the prompt template is fed into - // image_processing_qwen2_vl.preprocess where it is further augmented - // by replacing tokens according to imageTHW. - // - // Neither the structured content nor the postprocessing of the template - // are supported in current Tokenizer/Jinja (swift) so handle that here. - - var messages = prompt.asMessages() - if messages[0]["role"] != "system" { - messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0) - } - - let lastIndex = messages.count - 1 - var lastMessage = messages[lastIndex]["content"] ?? "" - - // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image - let mergeLength = config.mergeSize * config.mergeSize - for thw in imageTHW ?? [] { - lastMessage += "<|vision_start|>" - lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength) - .joined() - lastMessage += "<|vision_end|>" - } - - for thw in videoTHW ?? [] { - lastMessage += "<|vision_start|>" - lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength) - .joined() - lastMessage += "<|vision_end|>" - } - - messages[lastIndex]["content"] = lastMessage +// MARK: - Processor - return - messages - .map { - "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? 
"")<|im_end|>" - } - .joined(separator: "\n") - + "\n<|im_start|>assistant\n" - } - - public func prepare(input: UserInput) async throws -> LMInput { - if input.images.isEmpty && input.videos.isEmpty { - // just a straight text prompt - let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil) - let promptTokens = try tokenizer.encode(text: prompt) - return LMInput(tokens: MLXArray(promptTokens)) - } - - // image_processing_qwen2_vl.preprocess - let images = try input.images.map { - try preprocess(images: [$0.asCIImage()], processing: input.processing) - } - - var videosAsImageSequences = [[CIImage]]() - for video in input.videos { - if let imageSequence = try? await MediaProcessing.asCIImageSequence( - video.asAVAsset(), samplesPerSecond: 2) - { - videosAsImageSequences.append(imageSequence) - } - } - let videos = try videosAsImageSequences.map { - try preprocess(images: $0, processing: input.processing) - } - - let imagePixels: MLXArray? - let image: LMInput.ProcessedImage? - if !images.isEmpty { - imagePixels = concatenated(images.map { $0.0 }) - image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 }) - } else { - imagePixels = nil - image = nil - } - - let videoPixels: MLXArray? - let video: LMInput.ProcessedVideo? - if !videos.isEmpty { - videoPixels = concatenated(videos.map { $0.0 }) - video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 }) - } else { - videoPixels = nil - video = nil - } - - let prompt = prepare( - prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw) - let promptTokens = try tokenizer.encode(text: prompt) - let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) - let mask = ones(like: promptArray).asType(.int8) - - return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video) - } -} - -public struct Qwen25VLProcessorConfiguration: Codable, Sendable { +/// Qwen25VL VLM `UserInputProcessor`. +/// +/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``. 
+///
+public typealias Qwen25VLProcessor = QwenVLProcessor<Qwen25VLProcessorConfiguration>
+// Configuration for ``Qwen25VLProcessor``
+public struct Qwen25VLProcessorConfiguration: QwenVLProcessorConfiguration {
     public let imageMean: [CGFloat]
     public let imageStd: [CGFloat]
     public let maxPixels: Int
@@ -1194,13 +974,6 @@ public struct Qwen25VLProcessorConfiguration: Codable, Sendable {
     public let patchSize: Int
     public let temporalPatchSize: Int
 
-    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageMean[0], imageMean[1], imageMean[2])
-    }
-    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageStd[0], imageStd[1], imageStd[2])
-    }
-
     enum CodingKeys: String, CodingKey {
         case imageMean = "image_mean"
         case imageStd = "image_std"
@@ -1210,4 +983,12 @@ public struct Qwen25VLProcessorConfiguration: Codable, Sendable {
         case patchSize = "patch_size"
         case temporalPatchSize = "temporal_patch_size"
     }
+
+    private var chatTemplate: String {
+        "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+    }
+
+    public func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] {
+        return try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate)
+    }
 }
diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift
index f71e2352..7e168635 100644
--- a/Libraries/MLXVLM/Models/Qwen2VL.swift
+++ b/Libraries/MLXVLM/Models/Qwen2VL.swift
@@ -583,203 +583,8 @@ private enum Vision {
 /// Qwen2VL VLM `UserInputProcessor`.
 ///
 /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``. 
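Before the matching deletion of Qwen2VLProcessor below, note the one behavioral difference that survives the refactor: each configuration decides how messages become tokens. Qwen2-VL can rely on the chat template bundled with its tokenizer, while Qwen2.5-VL passes the explicit Jinja string above (the old processor assembled the prompt by hand because the structured template was not usable from Swift). A sketch of that dispatch, with stand-in struct names; only the two tokenizer calls are taken from the patch:

    import MLXLMCommon
    import Tokenizers

    // Qwen2-VL style: defer to the template shipped with the tokenizer.
    struct BundledTemplateConfiguration {
        func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] {
            try tokenizer.applyChatTemplate(messages: messages)
        }
    }

    // Qwen2.5-VL style: override with an explicit Jinja template string.
    struct ExplicitTemplateConfiguration {
        let chatTemplate: String
        func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] {
            try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate)
        }
    }
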
-public class Qwen2VLProcessor: UserInputProcessor { - private let config: Qwen2VLProcessorConfiguration - private let tokenizer: any Tokenizer - - public init(_ config: Qwen2VLProcessorConfiguration, tokenizer: any Tokenizer) { - self.config = config - self.tokenizer = tokenizer - } - - // image_processing_qwen2_vl.smart_resize - private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) - var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor - wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - - public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( - MLXArray, THW - ) { - // first apply the user requested resizing, etc. if any - let images = images.map { MediaProcessing.apply($0, processing: processing) } - - // image_processing_qwen2_vl._preprocess - - let size = images[0].extent.size - let (resizedHeight, resizedWidth) = try targetSize( - height: Int(size.height), width: Int(size.width), - factor: config.patchSize * config.mergeSize, - minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) - let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) - - let processedImages = - try images - .map { - MediaProcessing.inSRGBToneCurveSpace($0) - } - .map { - return MediaProcessing.resampleBicubic($0, to: resizedSize) - } - .map { - MediaProcessing.normalize( - $0, mean: config.imageMeanTuple, std: config.imageStdTuple) - } - .map { - MediaProcessing.asMLXArray($0) - } - - var patches = concatenated(processedImages) - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) - - let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize - ) - - return (flattenedPatches, .init(gridT, gridH, gridW)) - } - - public func prepare(input: UserInput) async throws -> LMInput { - let messages 
= input.prompt.asMessages() - var promptTokens = try tokenizer.applyChatTemplate(messages: messages) - - // Text-only input - if input.images.isEmpty, input.videos.isEmpty { - return LMInput(tokens: MLXArray(promptTokens)) - } - - // Process images if any - var processedImage: LMInput.ProcessedImage? - if !input.images.isEmpty { - let imagePixelsAndFrames = try input.images.map { - try preprocess(images: [$0.asCIImage()], processing: input.processing) - } - let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) - processedImage = LMInput.ProcessedImage( - pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) - if let imageFrames = processedImage?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") - } - } - - // Process videos if any - var processedVideo: LMInput.ProcessedVideo? - if !input.videos.isEmpty { - var videosAsImageSequences = [[CIImage]]() - for video in input.videos { - if let imageSequence = try? await MediaProcessing.asCIImageSequence( - video.asAVAsset(), samplesPerSecond: 2) - { - videosAsImageSequences.append(imageSequence) - } - } - let videoPixelsAndFrames = try videosAsImageSequences.map { - try preprocess(images: $0, processing: input.processing) - } - let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) - processedVideo = LMInput.ProcessedVideo( - pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) - if let videoFrames = processedVideo?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") - } - } - - let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) - let mask = ones(like: promptArray).asType(.int8) - return LMInput( - text: .init(tokens: promptArray, mask: mask), - image: processedImage, - video: processedVideo) - } - - func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) - throws -> [Int] - { - // Replace single padding token with correct number for each image or video frame - let placeholderTokens = try tokenizer.encode( - text: "<|vision_start|>\(paddingToken)<|vision_end|>") - let placeholderRanges = promptTokens.ranges(of: placeholderTokens) - guard placeholderRanges.count == frames.count else { - throw VLMError.processing( - "Number of placeholder tokens does not match number of frames") - } - let mergeLength = config.mergeSize * config.mergeSize - let replacementSequences = try frames.map { frame in - let paddingCount = frame.product / mergeLength - return try tokenizer.encode( - text: - "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" - ) - } - // Build the final array - var result: [Int] = [] - var currentIndex = promptTokens.startIndex - for (range, replacement) in zip(placeholderRanges, replacementSequences) { - // Add tokens before the placeholder - result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) - // Add replacement sequence - result.append(contentsOf: replacement) - currentIndex = range.upperBound - } - // Add any remaining tokens after the last replacement - if currentIndex < promptTokens.endIndex { - result.append(contentsOf: promptTokens[currentIndex...]) - } - return result - } -} +/// +public typealias Qwen2VLProcessor = QwenVLProcessor // MARK: - Model @@ -1026,8 +831,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable { } /// Configuration for ``Qwen2VLProcessor`` -public struct 
Qwen2VLProcessorConfiguration: Codable, Sendable {
-
+public struct Qwen2VLProcessorConfiguration: QwenVLProcessorConfiguration {
     public struct Size: Codable, Sendable {
         public let maxPixels: Int
         public let minPixels: Int
@@ -1045,12 +849,8 @@ public struct Qwen2VLProcessorConfiguration: Codable, Sendable {
     public let patchSize: Int
     public let temporalPatchSize: Int
 
-    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageMean[0], imageMean[1], imageMean[2])
-    }
-    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageStd[0], imageStd[1], imageStd[2])
-    }
+    public var maxPixels: Int { size.maxPixels }
+    public var minPixels: Int { size.minPixels }
 
     enum CodingKeys: String, CodingKey {
         case imageMean = "image_mean"
@@ -1060,4 +860,8 @@ public struct Qwen2VLProcessorConfiguration: Codable, Sendable {
         case patchSize = "patch_size"
         case temporalPatchSize = "temporal_patch_size"
     }
+
+    public func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] {
+        return try tokenizer.applyChatTemplate(messages: messages)
+    }
 }
diff --git a/Libraries/MLXVLM/Models/QwenVLProcessor.swift b/Libraries/MLXVLM/Models/QwenVLProcessor.swift
new file mode 100644
index 00000000..b82a4a99
--- /dev/null
+++ b/Libraries/MLXVLM/Models/QwenVLProcessor.swift
@@ -0,0 +1,231 @@
+import CoreImage
+import Foundation
+import Hub
+import MLX
+import MLXFast
+import MLXLMCommon
+import MLXNN
+import Tokenizers
+
+public protocol QwenVLProcessorConfiguration: Codable, Sendable {
+    var imageMean: [CGFloat] { get }
+    var imageStd: [CGFloat] { get }
+    var maxPixels: Int { get }
+    var minPixels: Int { get }
+    var mergeSize: Int { get }
+    var patchSize: Int { get }
+    var temporalPatchSize: Int { get }
+
+    var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { get }
+    var imageStdTuple: (CGFloat, CGFloat, CGFloat) { get }
+
+    func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int]
+}
+
+// Default implementation for common properties
+extension QwenVLProcessorConfiguration {
+    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageMean[0], imageMean[1], imageMean[2])
+    }
+    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
+        (imageStd[0], imageStd[1], imageStd[2])
+    }
+}
+
+// Base processor class
+public class QwenVLProcessor<Config: QwenVLProcessorConfiguration>: UserInputProcessor {
+    private let config: Config
+    private let tokenizer: any Tokenizer
+
+    public init(_ config: Config, tokenizer: any Tokenizer) {
+        self.config = config
+        self.tokenizer = tokenizer
+    }
+
+    private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
+        throws -> (Int, Int)
+    {
+        if height < factor {
+            throw VLMError.imageProcessingFailure(
+                "height: \(height) must be larger than factor: \(factor)")
+        }
+        if width < factor {
+            throw VLMError.imageProcessingFailure(
+                "width: \(width) must be larger than factor: \(factor)")
+        }
+        if max(height, width) / min(height, width) > 200 {
+            throw VLMError.imageProcessingFailure(
+                "absolute aspect ratio must be smaller than 200: \(width)x\(height)")
+        }
+
+        var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
+        var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
+
+        if hBar * wBar > maxPixels {
+            let beta = sqrt(Float(height * width) / Float(maxPixels))
+            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
+            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
+        } else if hBar * wBar < minPixels {
+            let beta = sqrt(Float(minPixels) / Float(height * width))
+            hBar = 
Int(floor(Float(height) * beta / Float(factor))) * factor + wBar = Int(floor(Float(width) * beta / Float(factor))) * factor + } + return (hBar, wBar) + } + + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( + MLXArray, THW + ) { + // first apply the user requested resizing, etc. if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + + let size = images[0].extent.size + let (resizedHeight, resizedWidth) = try targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.minPixels, maxPixels: config.maxPixels) + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + var patches = concatenated(processedImages) + let mod = patches.dim(0) % config.temporalPatchSize + if mod != 0 { + let lastPatch = patches[-1, .ellipsis] + let lastPatchRepeated = tiled( + lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) + patches = concatenated([patches, lastPatchRepeated]) + } + let channel = patches.dim(1) + let gridT = patches.dim(0) / self.config.temporalPatchSize + let gridH = resizedHeight / self.config.patchSize + let gridW = resizedWidth / self.config.patchSize + + patches = patches.reshaped( + gridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) + + let flattenedPatches = patches.reshaped( + gridT * gridH * gridW, + channel * config.temporalPatchSize * config.patchSize * config.patchSize + ) + + return (flattenedPatches, .init(gridT, gridH, gridW)) + } + + public func prepare(input: UserInput) async throws -> LMInput { + let messages = input.prompt.asMessages() + var promptTokens = try config.applyChatTemplate(messages: messages, tokenizer: tokenizer) + + // Text-only input + if input.images.isEmpty, input.videos.isEmpty { + return LMInput(tokens: MLXArray(promptTokens)) + } + + // Process images if any + var processedImage: LMInput.ProcessedImage? + if !input.images.isEmpty { + let imagePixelsAndFrames = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) + processedImage = LMInput.ProcessedImage( + pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { + promptTokens = try replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") + } + } + + // Process videos if any + var processedVideo: LMInput.ProcessedVideo? + if !input.videos.isEmpty { + var videosAsImageSequences = [[CIImage]]() + for video in input.videos { + if let imageSequence = try? 
await MediaProcessing.asCIImageSequence(
+                        video.asAVAsset(), samplesPerSecond: 2)
+                {
+                    videosAsImageSequences.append(imageSequence)
+                }
+            }
+            let videoPixelsAndFrames = try videosAsImageSequences.map {
+                try preprocess(images: $0, processing: input.processing)
+            }
+            let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
+            processedVideo = LMInput.ProcessedVideo(
+                pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
+            if let videoFrames = processedVideo?.frames {
+                promptTokens = try replacePaddingTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
+            }
+        }
+
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+        return LMInput(
+            text: .init(tokens: promptArray, mask: mask),
+            image: processedImage,
+            video: processedVideo)
+    }
+
+    func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
+        throws -> [Int]
+    {
+        // Replace single padding token with correct number for each image or video frame
+        let placeholderTokens = try tokenizer.encode(
+            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
+        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
+        guard placeholderRanges.count == frames.count else {
+            throw VLMError.processing(
+                "Number of placeholder tokens does not match number of frames")
+        }
+        let mergeLength = config.mergeSize * config.mergeSize
+        let replacementSequences = try frames.map { frame in
+            let paddingCount = frame.product / mergeLength
+            return try tokenizer.encode(
+                text:
+                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
+            )
+        }
+        // Build the final array
+        var result: [Int] = []
+        var currentIndex = promptTokens.startIndex
+        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
+            // Add tokens before the placeholder
+            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
+            // Add replacement sequence
+            result.append(contentsOf: replacement)
+            currentIndex = range.upperBound
+        }
+        // Add any remaining tokens after the last replacement
+        if currentIndex < promptTokens.endIndex {
+            result.append(contentsOf: promptTokens[currentIndex...])
+        }
+        return result
+    }
+}

From 392dee5fd9b2b76f044b18b99f087fdf9945484a Mon Sep 17 00:00:00 2001
From: Sachin Desai
Date: Thu, 6 Feb 2025 20:06:28 -0800
Subject: [PATCH 3/5] remove unused imports

---
 Libraries/MLXVLM/Models/QwenVLProcessor.swift | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/Libraries/MLXVLM/Models/QwenVLProcessor.swift b/Libraries/MLXVLM/Models/QwenVLProcessor.swift
index b82a4a99..4d7a0028 100644
--- a/Libraries/MLXVLM/Models/QwenVLProcessor.swift
+++ b/Libraries/MLXVLM/Models/QwenVLProcessor.swift
@@ -1,10 +1,7 @@
 import CoreImage
 import Foundation
-import Hub
 import MLX
-import MLXFast
 import MLXLMCommon
-import MLXNN
 import Tokenizers
 
 public protocol QwenVLProcessorConfiguration: Codable, Sendable {

From 270bb8b18140b4fe4be6c58e311ed5971b166c5b Mon Sep 17 00:00:00 2001
From: Sachin Desai
Date: Tue, 4 Mar 2025 15:19:17 -0800
Subject: [PATCH 4/5] refactor common code between Qwen2VL and Qwen25VL

---
 Libraries/MLXVLM/Models/Qwen25VL.swift        | 566 ++--------------
 Libraries/MLXVLM/Models/Qwen2VL.swift         | 385 ++---------
 Libraries/MLXVLM/Models/QwenVL.swift          | 632 ++++++++++++++++++
 Libraries/MLXVLM/Models/QwenVLProcessor.swift | 228 -------
 Package.swift                                 |   2 +-
 5 files changed, 744 insertions(+), 1069 deletions(-)
 create mode 100644 
Libraries/MLXVLM/Models/QwenVL.swift delete mode 100644 Libraries/MLXVLM/Models/QwenVLProcessor.swift diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 2e338709..ab54f4ec 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -16,157 +16,33 @@ import MLXLMCommon import MLXNN import Tokenizers -// MARK: - Common - -/// Rotates half the hidden dims of the input -private func rotateHalf(_ x: MLXArray) -> MLXArray { - let index = x.dim(-1) / 2 - let x1 = x[.ellipsis, 0 ..< index] - let x2 = x[.ellipsis, index...] - return concatenated([-x2, x1], axis: -1) -} - // MARK: - Language private enum Language { - static private func applyMultimodalRotaryPositionEmbedding( - q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, - positionIds: MLXArray, mropeSection: [Int] - ) -> (MLXArray, MLXArray) { - var cos = cos[positionIds] - var sin = sin[positionIds] - - cos = - concatenated( - // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] - split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, - axis: -1 - )[0..., .newAxis, 0..., 0...] - - sin = - concatenated( - split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, - axis: -1 - )[0..., .newAxis, 0..., 0...] - - // Apply rotary embedding - let qEmbed = (q * cos) + (rotateHalf(q) * sin) - let kEmbed = (k * cos) + (rotateHalf(k) * sin) - return (qEmbed, kEmbed) - } - - fileprivate class Attention: Module { - let heads: Int - let kvHeads: Int - let headDim: Int - let scale: Float - let mropeSection: [Int] - - @ModuleInfo(key: "q_proj") var wq: Linear - @ModuleInfo(key: "k_proj") var wk: Linear - @ModuleInfo(key: "v_proj") var wv: Linear - @ModuleInfo(key: "o_proj") var wo: Linear - - @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE - + fileprivate class Attention: QwenVLLanguage.Attention { public init(_ args: Qwen25VLConfiguration.TextConfiguration) { - let dim = args.hiddenSize - self.heads = args.attentionHeads - self.kvHeads = args.kvHeads - self.headDim = dim / heads - self.scale = pow(Float(headDim), -0.5) - - self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) - self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) - self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) - self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) - - if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { - self.mropeSection = sequence(state: (0, array.makeIterator())) { state in - if let v = state.1.next() { - // note the *2 - state.0 += v * 2 - return state.0 - } else { - return nil - } - }.dropLast() - } else { - fatalError("rope_scaling['mrope_section'] must be array of integers") - } - - self._rotaryEmbedding.wrappedValue = RoPE( - dimensions: headDim, - traditional: args.ropeTraditional, - base: args.ropeTheta + super.init( + hiddenSize: args.hiddenSize, + attentionHeads: args.attentionHeads, + kvHeads: args.kvHeads, + ropeTheta: args.ropeTheta, + ropeTraditional: args.ropeTraditional, + ropeScaling: args.ropeScaling ) } - - public func callAsFunction( - _ x: MLXArray, - mask: MLXArray? = nil, - cache: KVCache? 
- ) -> MLXArray { - let (B, L) = (x.dim(0), x.dim(1)) - - var queries = wq(x) - var keys = wk(x) - var values = wv(x) - - queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) - keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) - values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) - - let offset = cache?.offset ?? 0 - let mask = mask?[0..., 0 ..< keys.dim(-2)] - - queries = rotaryEmbedding(queries, offset: offset) - keys = rotaryEmbedding(keys, offset: offset) - - if let cache { - (keys, values) = cache.update(keys: keys, values: values) - } - - let output = MLXFast.scaledDotProductAttention( - queries: queries, - keys: keys, - values: values, - scale: scale, - mask: mask - ) - .transposed(0, 2, 1, 3) - .reshaped(B, L, -1) - - return wo(output) - } - } - - fileprivate class MLP: Module { - @ModuleInfo(key: "gate_proj") var gate: Linear - @ModuleInfo(key: "down_proj") var down: Linear - @ModuleInfo(key: "up_proj") var up: Linear - - public init(dimensions: Int, hiddenDimensions: Int) { - self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) - self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) - self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) - } - - public func callAsFunction(_ x: MLXArray) -> MLXArray { - down(silu(gate(x)) * up(x)) - } } fileprivate class Qwen25VLDecoderLayer: Module { @ModuleInfo(key: "self_attn") var attention: Attention - let mlp: MLP + let mlp: QwenVLLanguage.MLP @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm public init(_ args: Qwen25VLConfiguration.TextConfiguration) { self._attention.wrappedValue = Attention(args) - self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self.mlp = QwenVLLanguage.MLP( + dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) self._inputLayerNorm.wrappedValue = RMSNorm( dimensions: args.hiddenSize, eps: args.rmsNormEps) self._postAttentionLayerNorm.wrappedValue = RMSNorm( @@ -258,162 +134,6 @@ private enum Language { // MARK: - Vision private enum Vision { - static fileprivate func applyMultimodalRotaryPositionEmbedding( - _ tensor: MLXArray, freqs: MLXArray - ) -> MLXArray { - var cos = cos(freqs) - var sin = sin(freqs) - - cos = expandedDimensions(cos, axis: 1) - cos = tiled(cos, repetitions: [1, 1, 2]) - cos = expandedDimensions(cos, axis: 0) - - sin = expandedDimensions(sin, axis: 1) - sin = tiled(sin, repetitions: [1, 1, 2]) - sin = expandedDimensions(sin, axis: 0) - - let output = (tensor * cos) + (rotateHalf(tensor) * sin) - return output.asType(tensor.dtype) - } - - fileprivate class VisionRotaryEmbedding: Module { - let dimensions: Int - let theta: Float - - init(dimensions: Int, theta: Float = 10000.0) { - self.dimensions = dimensions - self.theta = theta - } - - func callAsFunction(_ sequenceLength: Int) -> MLXArray { - let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions - let inverseFreq = 1.0 / pow(theta, p) - let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) - return outer(seq, inverseFreq) - } - } - - fileprivate class PatchEmbed: Module { - @ModuleInfo var proj: Conv3d - - let patchSize: Int - let temporalPatchSize: Int - let inChannels: Int - let hiddenSize: Int - - init( - patchSize: Int = 14, - temporalPatchSize: Int = 2, - inChannels: Int = 3, - hiddenSize: Int = 1152 - ) { - self.patchSize = 
patchSize - self.temporalPatchSize = temporalPatchSize - self.inChannels = inChannels - self.hiddenSize = hiddenSize - - let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) - self._proj.wrappedValue = Conv3d( - inputChannels: inChannels, - outputChannels: hiddenSize, - kernelSize: kernelSize, - stride: kernelSize, - bias: false - ) - } - - func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { - var hiddenStates = hiddenStates.reshaped( - -1, inChannels, temporalPatchSize, patchSize, patchSize - ).movedAxis(source: 1, destination: 4) - - hiddenStates = proj(hiddenStates) - hiddenStates = hiddenStates.reshaped(-1, hiddenSize) - return hiddenStates - } - } - - fileprivate class PatchMerger: Module { - let hiddenSize: Int - - @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm - @ModuleInfo var mlp: (Linear, GELU, Linear) - - init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int = 2) { - self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) - self._layerNormQ.wrappedValue = RMSNorm(dimensions: contextDimensions, eps: 1e-6) - self.mlp = ( - Linear(hiddenSize, hiddenSize), - GELU(), - Linear(hiddenSize, dimensions) - ) - } - - func callAsFunction(_ x: MLXArray) -> MLXArray { - var x = layerNormQ(x).reshaped(-1, hiddenSize) - x = mlp.0(x) - x = mlp.1(x) - x = mlp.2(x) - return x - } - } - - fileprivate class Attention: Module { - let numHeads: Int - let scale: Float - - @ModuleInfo(key: "qkv") var qkv: Linear - @ModuleInfo(key: "proj") var proj: Linear - - init(dim: Int, numHeads: Int = 16) { - self.numHeads = numHeads - let headDim = dim / numHeads - self.scale = pow(Float(headDim), -0.5) - - self._qkv.wrappedValue = Linear(dim, dim * 3, bias: true) - self._proj.wrappedValue = Linear(dim, dim) - } - - func callAsFunction( - _ x: MLXArray, - frames: [THW], - rotaryPositionEmbedding: MLXArray? 
= nil - ) -> MLXArray { - let seqLength = x.dim(0) - let B = frames[0].t - let L = seqLength / B - - let qkv = qkv(x) - let s = split(qkv, parts: 3, axis: -1) - var (q, k, v) = (s[0], s[1], s[2]) - - q = q.reshaped(seqLength, numHeads, -1) - k = k.reshaped(seqLength, numHeads, -1) - v = v.reshaped(seqLength, numHeads, -1) - - if let rotaryPositionEmbedding { - q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) - k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) - } - - q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - - let output = MLXFast.scaledDotProductAttention( - queries: q, - keys: k, - values: v, - scale: scale, - mask: nil - ) - .transposed(0, 2, 1, 3) - .reshaped(seqLength, -1) - - return proj(output) - } - } - fileprivate class MLP: Module, UnaryLayer { @ModuleInfo(key: "gate_proj") var gate: Linear @ModuleInfo(key: "up_proj") var up: Linear @@ -431,17 +151,16 @@ private enum Vision { } fileprivate class Qwen25VLVisionBlock: Module { - @ModuleInfo var norm1: RMSNorm - @ModuleInfo var norm2: RMSNorm - @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo var norm1: LayerNorm + @ModuleInfo var norm2: LayerNorm + @ModuleInfo(key: "attn") var attention: QwenVLVision.Attention @ModuleInfo var mlp: MLP init(_ config: Qwen25VLConfiguration.VisionConfiguration) { - self.norm1 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) - self.norm2 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) - - self._attention.wrappedValue = Attention( - dim: config.hiddenSize, + self.norm1 = LayerNorm(dimensions: config.hiddenSize, eps: 1e-6) + self.norm2 = LayerNorm(dimensions: config.hiddenSize, eps: 1e-6) + self._attention.wrappedValue = QwenVLVision.Attention( + dims: config.hiddenSize, numHeads: config.numHeads ) @@ -470,34 +189,25 @@ private enum Vision { } fileprivate class VisionModel: Module { - @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed - @ModuleInfo(key: "merger") var merger: PatchMerger - @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVLVision.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: + QwenVLVision.VisionRotaryEmbedding @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock] + @ModuleInfo(key: "merger") var patchMerger: QwenVLVision.PatchMerger - let windowSize: Int - let patchSize: Int let spatialMergeSize: Int - let spatialMergeUnit: Int - let fullAttBlockIndexes: [Int] init(_ config: Qwen25VLConfiguration.VisionConfiguration) { - self.windowSize = config.windowSize - self.patchSize = config.patchSize self.spatialMergeSize = config.spatialMergeSize - self.fullAttBlockIndexes = config.fullAttBlockIndexes - - self.spatialMergeUnit = spatialMergeSize * spatialMergeSize - - self._patchEmbed.wrappedValue = PatchEmbed( + self._patchEmbed.wrappedValue = QwenVLVision.PatchEmbed( patchSize: config.patchSize, temporalPatchSize: config.temporalPatchSize, inChannels: config.inChannels, - hiddenSize: config.hiddenSize + embedDimensions: config.hiddenSize ) let headDim = config.hiddenSize / config.numHeads - self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding( + self._rotaryPositionEmbedding.wrappedValue = QwenVLVision.VisionRotaryEmbedding( dimensions: headDim / 2 ) @@ -505,7 +215,7 @@ private enum Vision { Qwen25VLVisionBlock(config) } - 
self._merger.wrappedValue = PatchMerger( + self._patchMerger.wrappedValue = QwenVLVision.PatchMerger( dimensions: config.outHiddenSize, contextDimensions: config.hiddenSize ) } @@ -516,30 +226,19 @@ private enum Vision { outputHiddenStates: Bool = false ) -> MLXArray { var hiddenStates = patchEmbed(hiddenStates) - var rotaryPosEmb = getRotaryPosEmb(frames) - var (windowIndex, cuWindowSeqlens) = getWindowIndex(frames) + var rotaryPositionEmbedding = rotaryPositionEmbedding(frames) - // Assuming grid_thw has shape (batch_size, 3) let batchSize = frames.count - - // Window processing - for (layerNum, block) in blocks.enumerated() { + for block in blocks { hiddenStates = block( - hiddenStates, - frames: frames, - rotaryPositionEmbedding: rotaryPosEmb - ) + hiddenStates, frames: frames, + rotaryPositionEmbedding: rotaryPositionEmbedding) } - - hiddenStates = merger(hiddenStates) - let reverseIndices = argSort(windowIndex, axis: 0) - hiddenStates = hiddenStates[reverseIndices, 0...] - - return hiddenStates + return patchMerger(hiddenStates) } - private func getRotaryPosEmb(_ frames: [THW]) -> MLXArray { - var posIds = [MLXArray]() + private func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray { + var positionIds = [MLXArray]() for row in frames { let (t, h, w) = row.values @@ -567,122 +266,14 @@ private enum Vision { .flattened() let stackedPosIds = stacked([hposIds, wposIds], axis: -1) - posIds.append(tiled(stackedPosIds, repetitions: [t, 1])) + positionIds.append(tiled(stackedPosIds, repetitions: [t, 1])) } - let indices = concatenated(posIds, axis: 0) + let indices = concatenated(positionIds, axis: 0) let maxFrameSize = frames.lazy.map({ max($0.h, $0.w) }).max() ?? 0 - let rotaryPosEmb = rotaryPositionEmbedding(maxFrameSize) - let rotaryPosEmbFull = rotaryPosEmb[indices] + let rotaryPositionEmbedFull = rotaryPositionEmbedding(maxFrameSize)[indices] - return rotaryPosEmbFull.reshaped(indices.dim(0), -1) - } - - private func getWindowIndex(_ frames: [THW]) -> (MLXArray, MLXArray) { - var windowIndex = [MLXArray]() - var cuWindowSeqlens = [0] - var windowIndexId = [0] - let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize - - for row in frames { - let (gridT, gridH, gridW) = row.values - let llmGridH = gridH / spatialMergeSize - let llmGridW = gridW / spatialMergeSize - - // Create initial index array - let index = MLXArray(0 ..< (gridT * llmGridH * llmGridW)).reshaped( - gridT, llmGridH, llmGridW) - - // Calculate padding and window dimensions - let padH = vitMergerWindowSize - llmGridH % vitMergerWindowSize - let padW = vitMergerWindowSize - llmGridW % vitMergerWindowSize - let numWindowsH = (llmGridH + padH) / vitMergerWindowSize - let numWindowsW = (llmGridW + padW) / vitMergerWindowSize - - var indexPadded = padded( - index, - widths: [0, .init((0, padH)), .init((0, padW))], - mode: .constant, - value: MLXArray(-100, dtype: index.dtype)) - - // Reshape and transpose for window creation - indexPadded = indexPadded.reshaped( - gridT, - numWindowsH, - vitMergerWindowSize, - numWindowsW, - vitMergerWindowSize - ) - - indexPadded = - indexPadded - .transposed(0, 1, 3, 2, 4) - .reshaped( - gridT, - numWindowsH * numWindowsW, - vitMergerWindowSize, - vitMergerWindowSize - ) - - // Process sequence lengths and indices - let seqlens = sum(indexPadded .!= -100, axes: [2, 3]).reshaped(-1) - indexPadded = indexPadded.reshaped(-1) - - var indices = [Int]() - for (i, v) in indexPadded.asArray(Int.self).enumerated() { - if v != -100 { - indices.append(v) - } - } - - let indexNew = 
MLXArray(indices) - - // Update window index and cumulative sequence lengths - windowIndex.append(indexNew + windowIndexId) - let cuSeqlensTmp = - cumsum(seqlens, axis: 0) * spatialMergeUnit + (cuWindowSeqlens.last ?? 0) - - cuWindowSeqlens.append(contentsOf: cuSeqlensTmp.asArray(Int.self)) - windowIndexId += [gridT * llmGridH * llmGridW] - } - - // Create final arrays - let finalWindowIndex = concatenated(windowIndex, axis: 0) - let finalCuWindowSeqlens = MLXArray(cuWindowSeqlens) - return (finalWindowIndex, finalCuWindowSeqlens) - } - - private func getCuSeqlens(_ gridThw: [THW]) -> MLXArray { - var cuSeqlens = [MLXArray]() - - // Calculate cumulative sequence lengths for each item in batch - for row in gridThw { - let seqLen = row.h * row.w - let repeatedLen = repeated(MLXArray(seqLen), count: row.t, axis: 0) - cuSeqlens.append(repeatedLen) - } - - // Concatenate and process all sequence lengths - var result = concatenated(cuSeqlens, axis: 0) - result = cumsum(result.asType(.int32), axis: 0) - - var r = padded(result, width: .init((1, 0))) - - // Add leading zero for offset calculation - return r - } - - private func isMLXWeight(_ array: MLXArray) -> Bool { - if array.ndim != 4 && array.ndim != 5 { - return false - } - - if array.dim(-1) == 3 { - return true - } - - let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) - return outChannels >= kH && outChannels >= kW && kH == kW + return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) } func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { @@ -709,9 +300,24 @@ private enum Vision { return sanitizedWeights } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4 && array.ndim != 5 { + return false + } + + if array.dim(-1) == 3 { + return true + } + + let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) + return outChannels >= kH && outChannels >= kW && kH == kW + } } } +// MARK: - Main Model + public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel @@ -752,41 +358,12 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { return mergeInputIdsWithImageFeatures( inputIds: inputIds, inputEmbeds: inputEmbeds, - imageFeatures: hiddenStates + imageFeatures: hiddenStates, + imageTokenId: config.baseConfiguration.imageTokenId, + videoTokenId: config.baseConfiguration.videoTokenId ) } - private func mergeInputIdsWithImageFeatures( - inputIds: MLXArray, - inputEmbeds: MLXArray, - imageFeatures: MLXArray - ) -> MLXArray { - let imageTokenIndex = config.baseConfiguration.imageTokenId - let videoTokenIndex = config.baseConfiguration.videoTokenId - - var imageIndices = [Int]() - for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == imageTokenIndex || v == videoTokenIndex { - imageIndices.append(i) - } - } - - // Make sure shapes match before assignment - var result = inputEmbeds - if result.ndim == 2 { - result = result[.newAxis, 0..., 0...] - } - - if imageFeatures.ndim == 2 { - let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] - result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures - } else { - result[0..., MLXArray(imageIndices), 0...] = imageFeatures - } - - return result - } - public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) 
throws -> PrepareResult { @@ -841,10 +418,19 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { } } +// MARK: - Processor + +/// Qwen25VL VLM `UserInputProcessor`. +/// +/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``. +/// +public typealias Qwen25VLProcessor = QwenVLProcessor + // MARK: - Configuration +/// Configuration for ``Qwen25VL`` public struct Qwen25VLConfiguration: Codable, Sendable { - public struct TextConfiguration: Codable, Sendable { + public struct TextConfiguration: Codable, Sendable, QwenVLTextConfigurable { public let modelType: String public let hiddenSize: Int public let hiddenLayers: Int @@ -881,7 +467,7 @@ public struct Qwen25VLConfiguration: Codable, Sendable { } } - public struct VisionConfiguration: Codable, Sendable { + public struct VisionConfiguration: Codable, Sendable, QwenVLVisionConfigurable { public let depth: Int public let hiddenSize: Int public let intermediateSize: Int @@ -917,7 +503,7 @@ public struct Qwen25VLConfiguration: Codable, Sendable { } } - public struct BaseConfiguration: Codable, Sendable { + public struct BaseConfiguration: Codable, Sendable, QwenVLBaseConfiguration { public let modelType: String public let vocabularySize: Int public let imageTokenId: Int @@ -956,15 +542,9 @@ public struct Qwen25VLConfiguration: Codable, Sendable { } } -// MARK: - Processor - -/// Qwen25VL VLM `UserInputProcessor`. -/// -/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``. -/// -public typealias Qwen25VLProcessor = QwenVLProcessor +// MARK: - Processor Configuration -// Configuration for ``Qwen25VLProcessor`` +/// Configuration for ``Qwen25VLProcessor`` public struct Qwen25VLProcessorConfiguration: QwenVLProcessorConfiguration { public let imageMean: [CGFloat] public let imageStd: [CGFloat] @@ -983,12 +563,4 @@ public struct Qwen25VLProcessorConfiguration: QwenVLProcessorConfiguration { case patchSize = "patch_size" case temporalPatchSize = "temporal_patch_size" } - - private var chatTemplate: String { - "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" - } - - public func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] { - return try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate) - } } diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 7e168635..2b21028e 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -11,155 +11,33 @@ import 
MLXLMCommon import MLXNN import Tokenizers -// MARK: - Common - -/// Rotates half the hidden dims of the input -private func rotateHalf(_ x: MLXArray) -> MLXArray { - let index = x.dim(-1) / 2 - let x1 = x[.ellipsis, 0 ..< index] - let x2 = x[.ellipsis, index...] - return concatenated([-x2, x1], axis: -1) -} - // MARK: - Language private enum Language { - - /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors - static private func applyMultimodalRotaryPositionEmbedding( - q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, - positionIds: MLXArray, mropeSection: [Int] - ) -> (MLXArray, MLXArray) { - var cos = cos[positionIds] - var sin = sin[positionIds] - - cos = - concatenated( - // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] - split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, - axis: -1 - )[0..., .newAxis, 0..., 0...] - - sin = - concatenated( - split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, - axis: -1 - )[0..., .newAxis, 0..., 0...] - - // Apply rotary embedding - let qEmbed = (q * cos) + (rotateHalf(q) * sin) - let kEmbed = (k * cos) + (rotateHalf(k) * sin) - return (qEmbed, kEmbed) - } - - fileprivate class Attention: Module { - - let heads: Int - let kvHeads: Int - let headDim: Int - let scale: Float - let mropeSection: [Int] - - @ModuleInfo(key: "q_proj") var wq: Linear - @ModuleInfo(key: "k_proj") var wk: Linear - @ModuleInfo(key: "v_proj") var wv: Linear - @ModuleInfo(key: "o_proj") var wo: Linear - - @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE - + fileprivate class Attention: QwenVLLanguage.Attention { public init(_ args: Qwen2VLConfiguration.TextConfiguration) { - let dim = args.hiddenSize - self.heads = args.attentionHeads - self.kvHeads = args.kvHeads - self.headDim = dim / heads - self.scale = pow(Float(headDim), -0.5) - - self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) - self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) - self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) - self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) - - if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { - // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() - self.mropeSection = sequence(state: (0, array.makeIterator())) { state in - if let v = state.1.next() { - // note the *2 - state.0 += v * 2 - return state.0 - } else { - return nil - } - }.dropLast() - } else { - fatalError("rope_scaling['mrope_section'] must be an array of integers") - } - - self._rotaryEmbedding.wrappedValue = RoPE( - dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) - } - - public func callAsFunction( - _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? - ) -> MLXArray { - let (B, L) = (x.dim(0), x.dim(1)) - - var queries = wq(x) - var keys = wk(x) - var values = wv(x) - - // prepare the queries, keys and values for the attention computation - queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) - keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) - values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) - - let offset = cache?.offset ?? 
0 - let mask = mask?[0..., 0 ..< keys.dim(-2)] - - queries = rotaryEmbedding(queries, offset: offset) - keys = rotaryEmbedding(keys, offset: offset) - - if let cache { - (keys, values) = cache.update(keys: keys, values: values) - } - - let output = MLXFast.scaledDotProductAttention( - queries: queries, keys: keys, values: values, scale: scale, mask: mask + super.init( + hiddenSize: args.hiddenSize, + attentionHeads: args.attentionHeads, + kvHeads: args.kvHeads, + ropeTheta: args.ropeTheta, + ropeTraditional: args.ropeTraditional, + ropeScaling: args.ropeScaling ) - .transposed(0, 2, 1, 3) - .reshaped(B, L, -1) - - return wo(output) - } - } - - fileprivate class MLP: Module, UnaryLayer { - - @ModuleInfo(key: "gate_proj") var gate: Linear - @ModuleInfo(key: "down_proj") var down: Linear - @ModuleInfo(key: "up_proj") var up: Linear - - public init(dimensions: Int, hiddenDimensions: Int) { - self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) - self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) - self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) - } - - public func callAsFunction(_ x: MLXArray) -> MLXArray { - down(silu(gate(x)) * up(x)) } } fileprivate class Qwen2VLDecoderLayer: Module { - @ModuleInfo(key: "self_attn") var attention: Attention - let mlp: MLP + let mlp: QwenVLLanguage.MLP @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm public init(_ args: Qwen2VLConfiguration.TextConfiguration) { self._attention.wrappedValue = Attention(args) - self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self.mlp = QwenVLLanguage.MLP( + dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) self._inputLayerNorm.wrappedValue = RMSNorm( dimensions: args.hiddenSize, eps: args.rmsNormEps) self._postAttentionLayerNorm.wrappedValue = RMSNorm( @@ -178,7 +56,6 @@ private enum Language { } fileprivate class Qwen2Model: Module { - @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding fileprivate let layers: [Qwen2VLDecoderLayer] @@ -252,154 +129,7 @@ private enum Language { // MARK: - Vision private enum Vision { - - static fileprivate func applyMultimodalRotaryPositionEmbedding( - _ tensor: MLXArray, freqs: MLXArray - ) -> MLXArray { - var cos = cos(freqs) - var sin = sin(freqs) - - cos = expandedDimensions(cos, axis: 1) - cos = tiled(cos, repetitions: [1, 1, 2]) - cos = expandedDimensions(cos, axis: 0) - - sin = expandedDimensions(sin, axis: 1) - sin = tiled(sin, repetitions: [1, 1, 2]) - sin = expandedDimensions(sin, axis: 0) - - let output = (tensor * cos) + (rotateHalf(tensor) * sin) - return output.asType(tensor.dtype) - } - - fileprivate class VisionRotaryEmbedding { - let dimensions: Int - let theta: Float - let inverseFreq: MLXArray - - init(dimensions: Int, theta: Float) { - self.dimensions = dimensions - self.theta = theta - let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions - self.inverseFreq = 1.0 / pow(theta, p) - } - - func callAsFunction(sequenceLength: Int) -> MLXArray { - let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) - let freqs = outer(seq, inverseFreq) - return freqs - } - } - - fileprivate class PatchEmbed: Module, UnaryLayer { - @ModuleInfo var proj: Conv3d - - let patchSize: Int - let temporalPatchSize: Int - let inChannels: Int - let embedDimensions: Int - - init(patchSize: Int, temporalPatchSize: Int, 
inChannels: Int, embedDimensions: Int) { - self.patchSize = patchSize - self.temporalPatchSize = temporalPatchSize - self.inChannels = inChannels - self.embedDimensions = embedDimensions - - let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) - self._proj.wrappedValue = Conv3d( - inputChannels: inChannels, - outputChannels: embedDimensions, - kernelSize: kernelSize, - stride: kernelSize, - bias: false - ) - } - - func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { - var hiddenStates = hiddenStates.reshaped( - -1, inChannels, temporalPatchSize, patchSize, patchSize - ).movedAxis(source: 1, destination: 4) - - hiddenStates = proj(hiddenStates) - hiddenStates = hiddenStates.reshaped(-1, embedDimensions) - return hiddenStates - } - } - - fileprivate class PatchMerger: Module, UnaryLayer { - let hiddenSize: Int - @ModuleInfo(key: "ln_q") var layerNormQ: LayerNorm - @ModuleInfo var mlp: (Linear, GELU, Linear) - - init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int) { - self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) - self._layerNormQ.wrappedValue = LayerNorm(dimensions: contextDimensions, eps: 1e-6) - self.mlp = ( - Linear(hiddenSize, hiddenSize), - GELU(), - Linear(hiddenSize, dimensions) - ) - } - - func callAsFunction(_ x: MLXArray) -> MLXArray { - var x = layerNormQ(x).reshaped(-1, hiddenSize) - x = mlp.0(x) - x = mlp.1(x) - x = mlp.2(x) - return x - } - } - - fileprivate class Attention: Module { - - let numHeads: Int - let scale: Float - - @ModuleInfo(key: "qkv") var qkv: Linear - @ModuleInfo(key: "proj") var proj: Linear - - public init(dims: Int, numHeads: Int) { - self.numHeads = numHeads - let headDim = dims / numHeads - self.scale = pow(Float(headDim), -0.5) - - self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) - self._proj.wrappedValue = Linear(dims, dims) - } - - public func callAsFunction( - _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray - ) -> MLXArray { - let sequenceLength = x.dim(0) - let B = frames[0].t - let L = sequenceLength / B - - let qkv = qkv(x) - let s = split(qkv, parts: 3, axis: -1) - var (q, k, v) = (s[0], s[1], s[2]) - - q = q.reshaped(sequenceLength, numHeads, -1) - k = k.reshaped(sequenceLength, numHeads, -1) - v = v.reshaped(sequenceLength, numHeads, -1) - - q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) - k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) - - q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - - let output = MLXFast.scaledDotProductAttention( - queries: q, keys: k, values: v, scale: scale, mask: nil - ) - .transposed(0, 2, 1, 3) - .reshaped(sequenceLength, -1) - - return proj(output) - } - } - fileprivate class MLP: Module, UnaryLayer { - @ModuleInfo var activation: GELU @ModuleInfo var fc1: Linear @ModuleInfo var fc2: Linear @@ -416,17 +146,16 @@ private enum Vision { } fileprivate class Qwen2VLVisionBlock: Module { - @ModuleInfo var norm1: LayerNorm @ModuleInfo var norm2: LayerNorm - @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo(key: "attn") var attention: QwenVLVision.Attention @ModuleInfo var mlp: MLP public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6) self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6) - self._attention.wrappedValue = 
Attention( + self._attention.wrappedValue = QwenVLVision.Attention( dims: config.embedDimensions, numHeads: config.numHeads) let mlpHiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio) @@ -450,31 +179,31 @@ private enum Vision { } fileprivate class VisionModel: Module { - - @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed - @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVLVision.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: + QwenVLVision.VisionRotaryEmbedding @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock] - @ModuleInfo(key: "merger") var patchMerger: PatchMerger + @ModuleInfo(key: "merger") var patchMerger: QwenVLVision.PatchMerger let spatialMergeSize: Int public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { self.spatialMergeSize = config.spatialMergeSize - self._patchEmbed.wrappedValue = PatchEmbed( + self._patchEmbed.wrappedValue = QwenVLVision.PatchEmbed( patchSize: config.patchSize, temporalPatchSize: config.temporalPatchSize, inChannels: config.inChannels, embedDimensions: config.embedDimensions) let headDimensions = config.embedDimensions / config.numHeads - self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding( + self._rotaryPositionEmbedding.wrappedValue = QwenVLVision.VisionRotaryEmbedding( dimensions: headDimensions / 2, theta: 10_000) self._blocks.wrappedValue = (0 ..< config.depth).map { _ in Qwen2VLVisionBlock(config) } - self._patchMerger.wrappedValue = PatchMerger( + self._patchMerger.wrappedValue = QwenVLVision.PatchMerger( dimensions: config.hiddenSize, contextDimensions: config.embedDimensions, spatialMergeSize: 2) } @@ -517,7 +246,7 @@ private enum Vision { let indices = concatenated(positionIds, axis: 0) let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 
0 - let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[ + let rotaryPositionEmbedFull = rotaryPositionEmbedding(maxFrameSize)[ indices] return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) @@ -538,19 +267,6 @@ private enum Vision { return patchMerger(hiddenStates) } - private func isMLXWeight(_ array: MLXArray) -> Bool { - if array.ndim != 4, array.ndim != 5 { - return false - } - - if array.dim(-1) == 3 { - return true - } - - let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) - return outChannels >= kH && outChannels >= kW && kH == kW - } - func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { var sanitizedWeights = [String: MLXArray]() @@ -575,6 +291,19 @@ private enum Vision { return sanitizedWeights } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4, array.ndim != 5 { + return false + } + + if array.dim(-1) == 3 { + return true + } + + let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) + return outChannels >= kH && outChannels >= kW && kH == kW + } } } @@ -630,36 +359,11 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { // Insert special image tokens in the input_ids return mergeInputIdsWithImageFeatures( - inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates) - } - - private func mergeInputIdsWithImageFeatures( - inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray - ) -> MLXArray { - let imageTokenIndex = config.baseConfiguration.imageTokenId - let videoTokenIndex = config.baseConfiguration.videoTokenId - - var imageIndices = [Int]() - for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == imageTokenIndex || v == videoTokenIndex { - imageIndices.append(i) - } - } - - // Make sure shapes match before assignment - var result = inputEmbeds - if result.ndim == 2 { - result = result[.newAxis, 0..., 0...] - } - - if imageFeatures.ndim == 2 { - let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] - result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures - } else { - result[0..., MLXArray(imageIndices), 0...] = imageFeatures - } - - return result + inputIds: inputIds, + inputEmbeds: inputEmbeds, + imageFeatures: hiddenStates, + imageTokenId: config.baseConfiguration.imageTokenId, + videoTokenId: config.baseConfiguration.videoTokenId) } public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) 
throws @@ -718,7 +422,6 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { }) ) } - } // MARK: - Configuration @@ -726,7 +429,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { /// Configuration for ``Qwen2VL`` public struct Qwen2VLConfiguration: Codable, Sendable { - public struct TextConfiguration: Codable, Sendable { + public struct TextConfiguration: Codable, Sendable, QwenVLTextConfigurable { public let modelType: String public let hiddenSize: Int public let hiddenLayers: Int @@ -763,7 +466,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable { } } - public struct VisionConfiguration: Codable, Sendable { + public struct VisionConfiguration: Codable, Sendable, QwenVLVisionConfigurable { public let depth: Int public let embedDimensions: Int public let hiddenSize: Int @@ -793,7 +496,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable { } } - public struct BaseConfiguration: Codable, Sendable { + public struct BaseConfiguration: Codable, Sendable, QwenVLBaseConfiguration { public let modelType: String public let vocabularySize: Int public let imageTokenId: Int @@ -860,8 +563,4 @@ public struct Qwen2VLProcessorConfiguration: QwenVLProcessorConfiguration { case patchSize = "patch_size" case temporalPatchSize = "temporal_patch_size" } - - public func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] { - return try tokenizer.applyChatTemplate(messages: messages) - } } diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift new file mode 100644 index 00000000..70b2dd4c --- /dev/null +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -0,0 +1,632 @@ +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Common Utilities + +/// Rotates half the hidden dims of the input +public func rotateHalf(_ x: MLXArray) -> MLXArray { + let index = x.dim(-1) / 2 + let x1 = x[.ellipsis, 0 ..< index] + let x2 = x[.ellipsis, index...] + return concatenated([-x2, x1], axis: -1) +} + +// MARK: - Language Model Components + +public enum QwenVLLanguage { + /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors + public static func applyMultimodalRotaryPositionEmbedding( + q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, + positionIds: MLXArray, mropeSection: [Int] + ) -> (MLXArray, MLXArray) { + var cos = cos[positionIds] + var sin = sin[positionIds] + + cos = + concatenated( + // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] + split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + sin = + concatenated( + split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] 
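+        // The 3 rows of positionIds hold the temporal / height / width positions,
+        // and mropeSection holds cumulative split points along the head dim, so
+        // each split section is rotated with the (i % 3)-th position row. As a
+        // hypothetical example (values not taken from this patch): an
+        // mrope_section of [16, 24, 24] with headDim = 128 doubles to
+        // [32, 48, 48], giving the split points [32, 80] used here.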
+ + // Apply rotary embedding + let qEmbed = (q * cos) + (rotateHalf(q) * sin) + let kEmbed = (k * cos) + (rotateHalf(k) * sin) + return (qEmbed, kEmbed) + } + + public class Attention: Module { + let heads: Int + let kvHeads: Int + let headDim: Int + let scale: Float + let mropeSection: [Int] + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE + + public init(hiddenSize: Int, attentionHeads: Int, kvHeads: Int, ropeTheta: Float, ropeTraditional: Bool, ropeScaling: [String: StringOrNumber]?) { + self.heads = attentionHeads + self.kvHeads = kvHeads + self.headDim = hiddenSize / attentionHeads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(hiddenSize, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, hiddenSize, bias: false) + + if let v = ropeScaling?["mrope_section"], let array = v.asInts() { + // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() + self.mropeSection = sequence(state: (0, array.makeIterator())) { state in + if let v = state.1.next() { + // note the *2 + state.0 += v * 2 + return state.0 + } else { + return nil + } + }.dropLast() + } else { + fatalError("rope_scaling['mrope_section'] must be an array of integers") + } + + self._rotaryEmbedding.wrappedValue = RoPE( + dimensions: headDim, traditional: ropeTraditional, base: ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + + let offset = cache?.offset ?? 
0 + let mask = mask?[0..., 0 ..< keys.dim(-2)] + + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + + if let cache { + (keys, values) = cache.update(keys: keys, values: values) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + public class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } +} + +// MARK: - Vision Model Components + +public enum QwenVLVision { + public static func applyMultimodalRotaryPositionEmbedding( + _ tensor: MLXArray, freqs: MLXArray + ) -> MLXArray { + var cos = cos(freqs) + var sin = sin(freqs) + + cos = expandedDimensions(cos, axis: 1) + cos = tiled(cos, repetitions: [1, 1, 2]) + cos = expandedDimensions(cos, axis: 0) + + sin = expandedDimensions(sin, axis: 1) + sin = tiled(sin, repetitions: [1, 1, 2]) + sin = expandedDimensions(sin, axis: 0) + + let output = (tensor * cos) + (rotateHalf(tensor) * sin) + return output.asType(tensor.dtype) + } + + public class VisionRotaryEmbedding: Module { + let dimensions: Int + let theta: Float + + public init(dimensions: Int, theta: Float = 10000.0) { + self.dimensions = dimensions + self.theta = theta + } + + public func callAsFunction(_ sequenceLength: Int) -> MLXArray { + let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions + let inverseFreq = 1.0 / pow(theta, p) + let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) + return outer(seq, inverseFreq) + } + } + + public class PatchEmbed: Module { + @ModuleInfo var proj: Conv3d + + let patchSize: Int + let temporalPatchSize: Int + let inChannels: Int + let embedDimensions: Int + + public init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int) { + self.patchSize = patchSize + self.temporalPatchSize = temporalPatchSize + self.inChannels = inChannels + self.embedDimensions = embedDimensions + + let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) + self._proj.wrappedValue = Conv3d( + inputChannels: inChannels, + outputChannels: embedDimensions, + kernelSize: kernelSize, + stride: kernelSize, + bias: false + ) + } + + public func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { + var hiddenStates = hiddenStates.reshaped( + -1, inChannels, temporalPatchSize, patchSize, patchSize + ).movedAxis(source: 1, destination: 4) + + hiddenStates = proj(hiddenStates) + hiddenStates = hiddenStates.reshaped(-1, embedDimensions) + return hiddenStates + } + } + + class PatchMerger: Module { + let hiddenSize: Int + + @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm + @ModuleInfo var mlp: (Linear, GELU, Linear) + + init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int = 2) { + self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) + self._layerNormQ.wrappedValue = RMSNorm(dimensions: contextDimensions, eps: 1e-6) + self.mlp = ( + Linear(hiddenSize, 
hiddenSize), + GELU(), + Linear(hiddenSize, dimensions) + ) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = layerNormQ(x).reshaped(-1, hiddenSize) + x = mlp.0(x) + x = mlp.1(x) + x = mlp.2(x) + return x + } + } + + class Attention: Module { + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "qkv") var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + public init(dims: Int, numHeads: Int) { + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) + self._proj.wrappedValue = Linear(dims, dims) + } + + public func callAsFunction( + _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + let sequenceLength = x.dim(0) + let B = frames[0].t + let L = sequenceLength / B + + let qkv = qkv(x) + let s = split(qkv, parts: 3, axis: -1) + var (q, k, v) = (s[0], s[1], s[2]) + + q = q.reshaped(sequenceLength, numHeads, -1) + k = k.reshaped(sequenceLength, numHeads, -1) + v = v.reshaped(sequenceLength, numHeads, -1) + + q = QwenVLVision.applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + k = QwenVLVision.applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + + q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: q, keys: k, values: v, scale: scale, mask: nil + ) + .transposed(0, 2, 1, 3) + .reshaped(sequenceLength, -1) + + return proj(output) + } + } +} + +// MARK: - Model Configuration Base Classes + +/// Base protocol for Qwen VL configuration +public protocol QwenVLBaseConfiguration: Codable, Sendable { + var vocabularySize: Int { get } + var imageTokenId: Int { get } + var videoTokenId: Int { get } + var hiddenSize: Int { get } +} + +/// Base protocol for text configuration +public protocol QwenVLTextConfigurable: Codable, Sendable { + var hiddenSize: Int { get } + var hiddenLayers: Int { get } + var intermediateSize: Int { get } + var attentionHeads: Int { get } + var rmsNormEps: Float { get } + var vocabularySize: Int { get } + var kvHeads: Int { get } + var ropeTheta: Float { get } + var ropeTraditional: Bool { get } + var ropeScaling: [String: StringOrNumber]? 
{ get } + var tieWordEmbeddings: Bool { get } +} + +/// Base protocol for vision configuration +public protocol QwenVLVisionConfigurable: Codable, Sendable { + var patchSize: Int { get } + var inChannels: Int { get } + var temporalPatchSize: Int { get } + var spatialMergeSize: Int { get } +} + +// MARK: - Common Processor Configuration + +/// Configuration for the Qwen VL processor +public protocol QwenVLProcessorConfiguration: Codable, Sendable { + var imageMean: [CGFloat] { get } + var imageStd: [CGFloat] { get } + var maxPixels: Int { get } + var minPixels: Int { get } + var mergeSize: Int { get } + var patchSize: Int { get } + var temporalPatchSize: Int { get } + + var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { get } + var imageStdTuple: (CGFloat, CGFloat, CGFloat) { get } +} + +// Default implementation for common properties +extension QwenVLProcessorConfiguration { + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } +} + +// MARK: - Common VLM Model Functions + +public extension VLMModel where Self: Module, Self: KVCacheDimensionProvider { + /// Common implementation for merging input IDs with image features + func mergeInputIdsWithImageFeatures( + inputIds: MLXArray, + inputEmbeds: MLXArray, + imageFeatures: MLXArray, + imageTokenId: Int, + videoTokenId: Int + ) -> MLXArray { + var imageIndices = [Int]() + for (i, v) in inputIds.asArray(Int.self).enumerated() { + if v == imageTokenId || v == videoTokenId { + imageIndices.append(i) + } + } + + // Make sure shapes match before assignment + var result = inputEmbeds + if result.ndim == 2 { + result = result[.newAxis, 0..., 0...] + } + + if imageFeatures.ndim == 2 { + let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] + result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures + } else { + result[0..., MLXArray(imageIndices), 0...] 
= imageFeatures
+        }
+
+        return result
+    }
+
+    /// Helper method to determine if an array is in MLX weight format
+    func isMLXWeight(_ array: MLXArray) -> Bool {
+        if array.ndim != 4, array.ndim != 5 {
+            return false
+        }
+
+        if array.dim(-1) == 3 {
+            return true
+        }
+
+        let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3))
+        return outChannels >= kH && outChannels >= kW && kH == kW
+    }
+
+    /// Helper method to sanitize PyTorch weights for MLX
+    func sanitizeVisionWeights(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var sanitizedWeights = [String: MLXArray]()
+
+        for (k, v) in weights {
+            if k.contains("position_id") {
+                // Remove unused position_ids
+                continue
+            } else if k.contains("patch_embed.proj.weight") {
+                // PyTorch Conv3d weight tensors have shape:
+                //   [outChannels, inChannels, kD, kH, kW]
+                // MLX Conv3d expects the weight to have shape:
+                //   [outChannels, kD, kH, kW, inChannels]
+                if isMLXWeight(v) {
+                    sanitizedWeights[k] = v
+                } else {
+                    sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1)
+                }
+            } else {
+                sanitizedWeights[k] = v
+            }
+        }
+
+        return sanitizedWeights
+    }
+}
+
+// MARK: - Common Processor Implementation
+
+/// Base implementation for Qwen VL processors
+public class QwenVLProcessor<Config: QwenVLProcessorConfiguration>: UserInputProcessor {
+
+    public let config: Config
+    public let tokenizer: any Tokenizer
+
+    public init(config: Config, tokenizer: any Tokenizer) {
+        self.config = config
+        self.tokenizer = tokenizer
+    }
+
+    func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
+        throws
+        -> (Int, Int)
+    {
+        if height < factor {
+            throw VLMError.imageProcessingFailure(
+                "height: \(height) must be larger than factor: \(factor)")
+        }
+        if width < factor {
+            throw VLMError.imageProcessingFailure(
+                "width: \(width) must be larger than factor: \(factor)")
+        }
+        if max(height, width) / min(height, width) > 200 {
+            throw VLMError.imageProcessingFailure(
+                "absolute aspect ratio must be smaller than 200: \(width)x\(height)")
+        }
+
+        var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
+        var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
+
+        if hBar * wBar > maxPixels {
+            let beta = sqrt(Float(height * width) / Float(maxPixels))
+            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
+            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
+        } else if hBar * wBar < minPixels {
+            let beta = sqrt(Float(minPixels) / Float(height * width))
+            hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
+            wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
+        }
+        return (hBar, wBar)
+    }
+
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {
+        // first apply the user requested resizing, etc. 
if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + let size = images[0].extent.size + + let (resizedHeight, resizedWidth) = try targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.minPixels, maxPixels: config.maxPixels) + + // Apply the calculated dimensions + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + var patches = concatenated(processedImages) + + // Handle temporal dimension + let mod = patches.dim(0) % config.temporalPatchSize + if mod != 0 { + let lastPatch = patches[-1, .ellipsis] + let lastPatchRepeated = tiled( + lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) + patches = concatenated([patches, lastPatchRepeated]) + } + + // Calculate grid dimensions + let channel = patches.dim(1) + let gridT = patches.dim(0) / self.config.temporalPatchSize + let gridH = patches.dim(2) / self.config.patchSize + let gridW = patches.dim(3) / self.config.patchSize + + patches = patches.reshaped( + gridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) + + let flattenedPatches = patches.reshaped( + gridT * gridH * gridW, + channel * config.temporalPatchSize * config.patchSize * config.patchSize + ) + + return (flattenedPatches, .init(gridT, gridH, gridW)) + } + + public func prepare(input: MLXLMCommon.UserInput) async throws -> MLXLMCommon.LMInput { + let messages = input.prompt.asMessages() + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) + + // Text-only input + if input.images.isEmpty, input.videos.isEmpty { + return LMInput(tokens: MLXArray(promptTokens)) + } + + // Process images if any + var processedImage: LMInput.ProcessedImage? + if !input.images.isEmpty { + let imagePixelsAndFrames = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) + processedImage = LMInput.ProcessedImage( + pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { + promptTokens = try replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") + } + } + + // Process videos if any + var processedVideo: LMInput.ProcessedVideo? + if !input.videos.isEmpty { + var videosAsImageSequences = [[CIImage]]() + for video in input.videos { + if let imageSequence = try? 
await MediaProcessing.asCIImageSequence( + video.asAVAsset(), samplesPerSecond: 2) + { + videosAsImageSequences.append(imageSequence) + } + } + let videoPixelsAndFrames = try videosAsImageSequences.map { + try preprocess(images: $0, processing: input.processing) + } + let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) + processedVideo = LMInput.ProcessedVideo( + pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { + promptTokens = try replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") + } + } + + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + return LMInput( + text: .init(tokens: promptArray, mask: mask), + image: processedImage, + video: processedVideo) + } + + func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) + throws -> [Int] + { + // Replace single padding token with correct number for each image or video frame + let placeholderTokens = try tokenizer.encode( + text: "<|vision_start|>\(paddingToken)<|vision_end|>") + let placeholderRanges = promptTokens.ranges(of: placeholderTokens) + guard placeholderRanges.count == frames.count else { + throw VLMError.processing( + "Number of placeholder tokens does not match number of frames") + } + let mergeLength = config.mergeSize * config.mergeSize + let replacementSequences = try frames.map { frame in + let paddingCount = frame.product / mergeLength + return try tokenizer.encode( + text: + "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" + ) + } + // Build the final array + var result: [Int] = [] + var currentIndex = promptTokens.startIndex + for (range, replacement) in zip(placeholderRanges, replacementSequences) { + // Add tokens before the placeholder + result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) + // Add replacement sequence + result.append(contentsOf: replacement) + currentIndex = range.upperBound + } + // Add any remaining tokens after the last replacement + if currentIndex < promptTokens.endIndex { + result.append(contentsOf: promptTokens[currentIndex...]) + } + return result + } +} diff --git a/Libraries/MLXVLM/Models/QwenVLProcessor.swift b/Libraries/MLXVLM/Models/QwenVLProcessor.swift deleted file mode 100644 index 4d7a0028..00000000 --- a/Libraries/MLXVLM/Models/QwenVLProcessor.swift +++ /dev/null @@ -1,228 +0,0 @@ -import CoreImage -import Foundation -import MLX -import MLXLMCommon -import Tokenizers - -public protocol QwenVLProcessorConfiguration: Codable, Sendable { - var imageMean: [CGFloat] { get } - var imageStd: [CGFloat] { get } - var maxPixels: Int { get } - var minPixels: Int { get } - var mergeSize: Int { get } - var patchSize: Int { get } - var temporalPatchSize: Int { get } - - var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { get } - var imageStdTuple: (CGFloat, CGFloat, CGFloat) { get } - - func applyChatTemplate(messages: [Message], tokenizer: any Tokenizer) throws -> [Int] -} - -// Default implementation for common properties -extension QwenVLProcessorConfiguration { - public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { - (imageMean[0], imageMean[1], imageMean[2]) - } - public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { - (imageStd[0], imageStd[1], imageStd[2]) - } -} - -// Base processor class -public class QwenVLProcessor: UserInputProcessor { - private let config: Config - private 
let tokenizer: any Tokenizer - - public init(_ config: Config, tokenizer: any Tokenizer) { - self.config = config - self.tokenizer = tokenizer - } - - private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) - var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(floor(Float(height) * beta / Float(factor))) * factor - wBar = Int(floor(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - - public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( - MLXArray, THW - ) { - // first apply the user requested resizing, etc. if any - let images = images.map { MediaProcessing.apply($0, processing: processing) } - - // image_processing_qwen2_vl._preprocess - - let size = images[0].extent.size - let (resizedHeight, resizedWidth) = try targetSize( - height: Int(size.height), width: Int(size.width), - factor: config.patchSize * config.mergeSize, - minPixels: config.minPixels, maxPixels: config.maxPixels) - let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) - - let processedImages = - try images - .map { - MediaProcessing.inSRGBToneCurveSpace($0) - } - .map { - MediaProcessing.resampleBicubic($0, to: resizedSize) - } - .map { - MediaProcessing.normalize( - $0, mean: config.imageMeanTuple, std: config.imageStdTuple) - } - .map { - MediaProcessing.asMLXArray($0) - } - - var patches = concatenated(processedImages) - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) - - let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize - ) - - return (flattenedPatches, .init(gridT, gridH, gridW)) - } - - public func prepare(input: UserInput) async throws -> LMInput { - let messages = input.prompt.asMessages() - var promptTokens = try config.applyChatTemplate(messages: messages, tokenizer: tokenizer) - - // Text-only input - if input.images.isEmpty, input.videos.isEmpty { - 
return LMInput(tokens: MLXArray(promptTokens)) - } - - // Process images if any - var processedImage: LMInput.ProcessedImage? - if !input.images.isEmpty { - let imagePixelsAndFrames = try input.images.map { - try preprocess(images: [$0.asCIImage()], processing: input.processing) - } - let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) - processedImage = LMInput.ProcessedImage( - pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) - if let imageFrames = processedImage?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") - } - } - - // Process videos if any - var processedVideo: LMInput.ProcessedVideo? - if !input.videos.isEmpty { - var videosAsImageSequences = [[CIImage]]() - for video in input.videos { - if let imageSequence = try? await MediaProcessing.asCIImageSequence( - video.asAVAsset(), samplesPerSecond: 2) - { - videosAsImageSequences.append(imageSequence) - } - } - let videoPixelsAndFrames = try videosAsImageSequences.map { - try preprocess(images: $0, processing: input.processing) - } - let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) - processedVideo = LMInput.ProcessedVideo( - pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) - if let videoFrames = processedVideo?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") - } - } - - let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) - let mask = ones(like: promptArray).asType(.int8) - return LMInput( - text: .init(tokens: promptArray, mask: mask), - image: processedImage, - video: processedVideo) - } - - func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) - throws -> [Int] - { - // Replace single padding token with correct number for each image or video frame - let placeholderTokens = try tokenizer.encode( - text: "<|vision_start|>\(paddingToken)<|vision_end|>") - let placeholderRanges = promptTokens.ranges(of: placeholderTokens) - guard placeholderRanges.count == frames.count else { - throw VLMError.processing( - "Number of placeholder tokens does not match number of frames") - } - let mergeLength = config.mergeSize * config.mergeSize - let replacementSequences = try frames.map { frame in - let paddingCount = frame.product / mergeLength - return try tokenizer.encode( - text: - "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" - ) - } - // Build the final array - var result: [Int] = [] - var currentIndex = promptTokens.startIndex - for (range, replacement) in zip(placeholderRanges, replacementSequences) { - // Add tokens before the placeholder - result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) - // Add replacement sequence - result.append(contentsOf: replacement) - currentIndex = range.upperBound - } - // Add any remaining tokens after the last replacement - if currentIndex < promptTokens.endIndex { - result.append(contentsOf: promptTokens[currentIndex...]) - } - return result - } -} diff --git a/Package.swift b/Package.swift index f0e8bc4a..a2c839a9 100644 --- a/Package.swift +++ b/Package.swift @@ -29,7 +29,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.21.2")), .package( - url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.17") + url: 
"https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.18") ), .package( url: "https://github.com/apple/swift-async-algorithms", .upToNextMinor(from: "1.0.0")), From 3072ad006af2a8d80439f6eef6b1ae9911870cb1 Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Tue, 4 Mar 2025 16:10:05 -0800 Subject: [PATCH 5/5] apply correct windowing fixes from @DePasqualeOrg --- Libraries/MLXVLM/Models/Qwen25VL.swift | 213 +++++++++++++++++++++++-- Libraries/MLXVLM/Models/Qwen2VL.swift | 55 ++++++- Libraries/MLXVLM/Models/QwenVL.swift | 69 ++------ 3 files changed, 269 insertions(+), 68 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index ab54f4ec..f2a3e232 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -134,6 +134,67 @@ private enum Language { // MARK: - Vision private enum Vision { + fileprivate class Attention: Module { + + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "qkv") var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + public init(dims: Int, numHeads: Int) { + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) + self._proj.wrappedValue = Linear(dims, dims) + } + + public func callAsFunction( + _ x: MLXArray, cuSeqlens: MLXArray, rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + let sequenceLength = x.dim(0) + + let qkv = qkv(x) + let s = split(qkv, parts: 3, axis: -1) + var (q, k, v) = (s[0], s[1], s[2]) + + q = q.reshaped(sequenceLength, numHeads, -1) + k = k.reshaped(sequenceLength, numHeads, -1) + v = v.reshaped(sequenceLength, numHeads, -1) + + q = QwenVLVision.applyMultimodalRotaryPositionEmbedding( + q, freqs: rotaryPositionEmbedding) + k = QwenVLVision.applyMultimodalRotaryPositionEmbedding( + k, freqs: rotaryPositionEmbedding) + + // Create attention mask + let attentionMask = full( + [1, sequenceLength, sequenceLength], + values: -Float32.greatestFiniteMagnitude) + + // Update mask for each sequence + for i in 1 ..< cuSeqlens.size { + let start = cuSeqlens[i - 1].item(Int.self) + let end = cuSeqlens[i].item(Int.self) + attentionMask[0..., start ..< end, start ..< end] = MLXArray(0) + } + + q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: q, keys: k, values: v, scale: scale, mask: attentionMask + ) + .transposed(0, 2, 1, 3) + .reshaped(sequenceLength, -1) + + return proj(output) + } + } + fileprivate class MLP: Module, UnaryLayer { @ModuleInfo(key: "gate_proj") var gate: Linear @ModuleInfo(key: "up_proj") var up: Linear @@ -153,13 +214,13 @@ private enum Vision { fileprivate class Qwen25VLVisionBlock: Module { @ModuleInfo var norm1: LayerNorm @ModuleInfo var norm2: LayerNorm - @ModuleInfo(key: "attn") var attention: QwenVLVision.Attention + @ModuleInfo(key: "attn") var attention: Attention @ModuleInfo var mlp: MLP init(_ config: Qwen25VLConfiguration.VisionConfiguration) { self.norm1 = LayerNorm(dimensions: config.hiddenSize, eps: 1e-6) self.norm2 = LayerNorm(dimensions: config.hiddenSize, eps: 1e-6) - self._attention.wrappedValue = QwenVLVision.Attention( + self._attention.wrappedValue = Attention( dims: config.hiddenSize, numHeads: config.numHeads ) @@ -172,14 +233,14 @@ private enum Vision 
{
 
     func callAsFunction(
         _ hiddenStates: MLXArray,
-        frames: [THW],
+        cuSeqlens: MLXArray,
         rotaryPositionEmbedding: MLXArray
     ) -> MLXArray {
         var hiddenStates =
             hiddenStates
             + attention(
                 norm1(hiddenStates),
-                frames: frames,
+                cuSeqlens: cuSeqlens,
                 rotaryPositionEmbedding: rotaryPositionEmbedding
             )
@@ -196,9 +257,18 @@ private enum Vision {
         @ModuleInfo(key: "merger") var patchMerger: QwenVLVision.PatchMerger
 
         let spatialMergeSize: Int
+        let windowSize: Int
+        let patchSize: Int
+        let spatialMergeUnit: Int
+        let fullAttBlockIndexes: [Int]
 
         init(_ config: Qwen25VLConfiguration.VisionConfiguration) {
             self.spatialMergeSize = config.spatialMergeSize
+            self.windowSize = config.windowSize
+            self.patchSize = config.patchSize
+            self.spatialMergeUnit = config.spatialMergeSize * config.spatialMergeSize
+            self.fullAttBlockIndexes = config.fullAttBlockIndexes
+
             self._patchEmbed.wrappedValue = QwenVLVision.PatchEmbed(
                 patchSize: config.patchSize,
                 temporalPatchSize: config.temporalPatchSize,
@@ -228,13 +298,53 @@ private enum Vision {
             var hiddenStates = patchEmbed(hiddenStates)
             var rotaryPositionEmbedding = rotaryPositionEmbedding(frames)
 
-            let batchSize = frames.count
-            for block in blocks {
+            // Get window indices and sequence lengths
+            let (windowIndex, cuWindowSeqlens) = getWindowIndex(frames)
+
+            // Reshape and reindex hidden states
+            let seqLen = hiddenStates.dim(0)
+            hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1)
+            hiddenStates = hiddenStates[windowIndex, 0..., 0...]
+            hiddenStates = hiddenStates.reshaped(seqLen, -1)
+
+            // Reshape and reindex rotary position embeddings
+            var rotaryPosEmbReshaped = rotaryPositionEmbedding.reshaped(
+                seqLen / spatialMergeUnit, spatialMergeUnit, -1)
+            rotaryPosEmbReshaped = rotaryPosEmbReshaped[windowIndex, 0..., 0...]
+            rotaryPosEmbReshaped = rotaryPosEmbReshaped.reshaped(seqLen, -1)
+
+            // Calculate cumulative sequence lengths for full attention
+            var cuSeqlens = [0]
+            for frame in frames {
+                let frameSeqLen = frame.h * frame.w
+                // Append one boundary per temporal frame. An explicit loop keeps
+                // the running total cumulative; a map over cuSeqlens.last! would
+                // reuse the same stale value for every repetition when frame.t > 1.
+                for _ in 0 ..< frame.t {
+                    cuSeqlens.append(cuSeqlens.last! + frameSeqLen)
+                }
+            }
+            let cuSeqlensArray = MLXArray(cuSeqlens)
+
+            // Process through blocks
+            for (i, block) in blocks.enumerated() {
+                // Use full attention for specific blocks, window attention for others
+                let cuSeqlensNow =
+                    fullAttBlockIndexes.contains(i) ? cuSeqlensArray : cuWindowSeqlens
+
                 hiddenStates = block(
-                    hiddenStates, frames: frames,
-                    rotaryPositionEmbedding: rotaryPositionEmbedding)
+                    hiddenStates,
+                    cuSeqlens: cuSeqlensNow,
+                    rotaryPositionEmbedding: rotaryPosEmbReshaped
+                )
             }
-            return patchMerger(hiddenStates)
+
+            // Apply patch merger
+            hiddenStates = patchMerger(hiddenStates)
+
+            // Reorder back to original sequence
+            let reverseIndices = argSort(windowIndex, axis: 0)
+            hiddenStates = hiddenStates[reverseIndices, 0...]
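+            // argSort of a permutation yields its inverse (for example, [2, 0, 1]
+            // argSorts to [1, 2, 0]), so this gather undoes the earlier windowIndex
+            // reordering and restores the original row-major patch order that the
+            // language model expects.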
+ + return hiddenStates } private func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray { @@ -276,6 +386,91 @@ private enum Vision { return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) } + func getWindowIndex(_ frames: [THW]) -> (MLXArray, MLXArray) { + var windowIndex = [MLXArray]() + var cuWindowSeqlens = [0] + var windowIndexId = 0 + let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize + + for frame in frames { + let (gridT, gridH, gridW) = frame.values + let llmGridH = gridH / spatialMergeSize + let llmGridW = gridW / spatialMergeSize + + let index = MLXArray(0 ..< (gridT * llmGridH * llmGridW)).reshaped( + gridT, llmGridH, llmGridW) + + let padH = vitMergerWindowSize - llmGridH % vitMergerWindowSize + let padW = vitMergerWindowSize - llmGridW % vitMergerWindowSize + let numWindowsH = (llmGridH + padH) / vitMergerWindowSize + let numWindowsW = (llmGridW + padW) / vitMergerWindowSize + + // Pad the index + let indexPadded = padded( + index, + widths: [[0, 0], [0, padH], [0, padW]], + mode: .constant, + value: MLXArray(-100) + ) + + // Reshape and transpose + let indexReshaped = indexPadded.reshaped( + gridT, + numWindowsH, + vitMergerWindowSize, + numWindowsW, + vitMergerWindowSize + ) + + let indexTransposed = indexReshaped.transposed(0, 1, 3, 2, 4).reshaped( + gridT, + numWindowsH * numWindowsW, + vitMergerWindowSize, + vitMergerWindowSize + ) + + // Calculate sequence lengths + let seqlens = sum(indexTransposed .!= -100, axes: [2, 3]).reshaped(-1) + + // Get valid indices + let indexFlattened = indexTransposed.flattened() + let validIndices = indexFlattened.asArray(Int.self).enumerated() + .filter { $0.element != -100 } + .map { $0.offset } + + let validValues = indexFlattened[MLXArray(validIndices)] + + // Add to window index + windowIndex.append(validValues + windowIndexId) + + // Update cumulative sequence lengths + let cuSeqlensTmp = + cumsum(seqlens, axis: 0) * spatialMergeUnit + cuWindowSeqlens.last! 
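+                // seqlens counts the valid (non -100) merged positions per window;
+                // multiplying by spatialMergeUnit converts back to raw patch tokens,
+                // and offsetting by the running last value keeps the boundaries
+                // cumulative across frames.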
+                cuWindowSeqlens.append(contentsOf: cuSeqlensTmp.asArray(Int.self))
+
+                windowIndexId += gridT * llmGridH * llmGridW
+            }
+
+            // Concatenate all window indices
+            let combinedWindowIndex = concatenated(windowIndex, axis: 0)
+            let cuWindowSeqlensArray = MLXArray(cuWindowSeqlens)
+
+            // Get unique values in cuWindowSeqlens
+            var seen = Set<Int>()
+            var uniqueIndices = [Int]()
+
+            for (i, value) in cuWindowSeqlens.enumerated() {
+                if !seen.contains(value) {
+                    seen.insert(value)
+                    uniqueIndices.append(i)
+                }
+            }
+
+            let uniqueCuWindowSeqlens = cuWindowSeqlensArray[MLXArray(uniqueIndices)]
+
+            return (combinedWindowIndex, uniqueCuWindowSeqlens)
+        }
+
         func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
             var sanitizedWeights = [String: MLXArray]()
 
diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift
index 2b21028e..c9ca69e7 100644
--- a/Libraries/MLXVLM/Models/Qwen2VL.swift
+++ b/Libraries/MLXVLM/Models/Qwen2VL.swift
@@ -129,6 +129,57 @@ private enum Language {
 // MARK: - Vision
 
 private enum Vision {
+    fileprivate class Attention: Module {
+
+        let numHeads: Int
+        let scale: Float
+
+        @ModuleInfo(key: "qkv") var qkv: Linear
+        @ModuleInfo(key: "proj") var proj: Linear
+
+        public init(dims: Int, numHeads: Int) {
+            self.numHeads = numHeads
+            let headDim = dims / numHeads
+            self.scale = pow(Float(headDim), -0.5)
+
+            self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true)
+            self._proj.wrappedValue = Linear(dims, dims)
+        }
+
+        public func callAsFunction(
+            _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
+        ) -> MLXArray {
+            let sequenceLength = x.dim(0)
+            let B = frames[0].t
+            let L = sequenceLength / B
+
+            let qkv = qkv(x)
+            let s = split(qkv, parts: 3, axis: -1)
+            var (q, k, v) = (s[0], s[1], s[2])
+
+            q = q.reshaped(sequenceLength, numHeads, -1)
+            k = k.reshaped(sequenceLength, numHeads, -1)
+            v = v.reshaped(sequenceLength, numHeads, -1)
+
+            q = QwenVLVision.applyMultimodalRotaryPositionEmbedding(
+                q, freqs: rotaryPositionEmbedding)
+            k = QwenVLVision.applyMultimodalRotaryPositionEmbedding(
+                k, freqs: rotaryPositionEmbedding)
+
+            q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
+            k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
+            v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
+
+            let output = MLXFast.scaledDotProductAttention(
+                queries: q, keys: k, values: v, scale: scale, mask: nil
+            )
+            .transposed(0, 2, 1, 3)
+            .reshaped(sequenceLength, -1)
+
+            return proj(output)
+        }
+    }
+
     fileprivate class MLP: Module, UnaryLayer {
         @ModuleInfo var activation: GELU
         @ModuleInfo var fc1: Linear
@@ -148,14 +199,14 @@ private enum Vision {
     fileprivate class Qwen2VLVisionBlock: Module {
         @ModuleInfo var norm1: LayerNorm
         @ModuleInfo var norm2: LayerNorm
-        @ModuleInfo(key: "attn") var attention: QwenVLVision.Attention
+        @ModuleInfo(key: "attn") var attention: Attention
         @ModuleInfo var mlp: MLP
 
         public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
             self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
             self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
 
-            self._attention.wrappedValue = QwenVLVision.Attention(
+            self._attention.wrappedValue = Attention(
                 dims: config.embedDimensions, numHeads: config.numHeads)
 
             let mlpHiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio)
diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift
index 70b2dd4c..9d15db82 100644
--- a/Libraries/MLXVLM/Models/QwenVL.swift
+++ 
b/Libraries/MLXVLM/Models/QwenVL.swift @@ -61,7 +61,10 @@ public enum QwenVLLanguage { @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE - public init(hiddenSize: Int, attentionHeads: Int, kvHeads: Int, ropeTheta: Float, ropeTraditional: Bool, ropeScaling: [String: StringOrNumber]?) { + public init( + hiddenSize: Int, attentionHeads: Int, kvHeads: Int, ropeTheta: Float, + ropeTraditional: Bool, ropeScaling: [String: StringOrNumber]? + ) { self.heads = attentionHeads self.kvHeads = kvHeads self.headDim = hiddenSize / attentionHeads @@ -239,54 +242,6 @@ public enum QwenVLVision { return x } } - - class Attention: Module { - let numHeads: Int - let scale: Float - - @ModuleInfo(key: "qkv") var qkv: Linear - @ModuleInfo(key: "proj") var proj: Linear - - public init(dims: Int, numHeads: Int) { - self.numHeads = numHeads - let headDim = dims / numHeads - self.scale = pow(Float(headDim), -0.5) - - self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) - self._proj.wrappedValue = Linear(dims, dims) - } - - public func callAsFunction( - _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray - ) -> MLXArray { - let sequenceLength = x.dim(0) - let B = frames[0].t - let L = sequenceLength / B - - let qkv = qkv(x) - let s = split(qkv, parts: 3, axis: -1) - var (q, k, v) = (s[0], s[1], s[2]) - - q = q.reshaped(sequenceLength, numHeads, -1) - k = k.reshaped(sequenceLength, numHeads, -1) - v = v.reshaped(sequenceLength, numHeads, -1) - - q = QwenVLVision.applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) - k = QwenVLVision.applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) - - q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) - - let output = MLXFast.scaledDotProductAttention( - queries: q, keys: k, values: v, scale: scale, mask: nil - ) - .transposed(0, 2, 1, 3) - .reshaped(sequenceLength, -1) - - return proj(output) - } - } } // MARK: - Model Configuration Base Classes @@ -350,9 +305,9 @@ extension QwenVLProcessorConfiguration { // MARK: - Common VLM Model Functions -public extension VLMModel where Self: Module, Self: KVCacheDimensionProvider { +extension VLMModel where Self: Module, Self: KVCacheDimensionProvider { /// Common implementation for merging input IDs with image features - func mergeInputIdsWithImageFeatures( + public func mergeInputIdsWithImageFeatures( inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray, @@ -383,7 +338,7 @@ public extension VLMModel where Self: Module, Self: KVCacheDimensionProvider { } /// Helper method to determine if an array is in MLX weight format - func isMLXWeight(_ array: MLXArray) -> Bool { + public func isMLXWeight(_ array: MLXArray) -> Bool { if array.ndim != 4, array.ndim != 5 { return false } @@ -397,7 +352,7 @@ public extension VLMModel where Self: Module, Self: KVCacheDimensionProvider { } /// Helper method to sanitize PyTorch weights for MLX - func sanitizeVisionWeights(weights: [String: MLXArray]) -> [String: MLXArray] { + public func sanitizeVisionWeights(weights: [String: MLXArray]) -> [String: MLXArray] { var sanitizedWeights = [String: MLXArray]() for (k, v) in weights { @@ -437,8 +392,8 @@ public class QwenVLProcessor: UserInputPro } func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws - -> (Int, Int) + throws + -> (Int, Int) { if height < factor { throw VLMError.imageProcessingFailure( @@ -486,7 +441,7 
@@ public class QwenVLProcessor: UserInputPro let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) let processedImages = - try images + try images .map { MediaProcessing.inSRGBToneCurveSpace($0) } @@ -595,7 +550,7 @@ public class QwenVLProcessor: UserInputPro } func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) - throws -> [Int] + throws -> [Int] { // Replace single padding token with correct number for each image or video frame let placeholderTokens = try tokenizer.encode(