From a6f552b2c7d657c755c9e8236ec8025e52158e20 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sat, 8 Mar 2025 10:48:49 +0100 Subject: [PATCH 01/17] Remove development team --- mlx-swift-examples.xcodeproj/project.pbxproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index 94c23d75..896e276b 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -1051,7 +1051,7 @@ CURRENT_PROJECT_VERSION = 1; DEBUG_INFORMATION_FORMAT = dwarf; DEVELOPMENT_ASSET_PATHS = "\"Applications/VLMEval/Preview Content\""; - DEVELOPMENT_TEAM = J5CY9Q9UP5; + DEVELOPMENT_TEAM = ""; ENABLE_PREVIEWS = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; @@ -1117,7 +1117,7 @@ CURRENT_PROJECT_VERSION = 1; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEVELOPMENT_ASSET_PATHS = "\"Applications/VLMEval/Preview Content\""; - DEVELOPMENT_TEAM = J5CY9Q9UP5; + DEVELOPMENT_TEAM = ""; ENABLE_NS_ASSERTIONS = NO; ENABLE_PREVIEWS = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; From b6532fa978da5c6fa36c0a5a4861e2636e1a7f5c Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Tue, 4 Mar 2025 19:08:58 +0100 Subject: [PATCH 02/17] Fix typos --- Libraries/MLXVLM/Models/Paligemma.swift | 4 ++-- Libraries/MLXVLM/VLMModelFactory.swift | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift index 76cc89ca..18c94113 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -441,7 +441,7 @@ private enum Vision { /// PaliGemma VLM `UserInputProcessor`. /// /// This is meant to be used with ``PaliGemma`` and is typically created by ``VLMModelFactory``. 
-public class PaligGemmaProcessor: UserInputProcessor { +public class PaliGemmaProcessor: UserInputProcessor { private let config: PaliGemmaProcessorConfiguration private let tokenizer: any Tokenizer @@ -705,7 +705,7 @@ public struct PaliGemmaConfiguration: Codable, Sendable { } } -/// Configuration for ``PaligGemmaProcessor`` +/// Configuration for ``PaliGemmaProcessor`` public struct PaliGemmaProcessorConfiguration: Codable, Sendable { public struct Size: Codable, Sendable { diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 73f654a9..22bcac69 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -98,7 +98,7 @@ public class ProcessorTypeRegistry: @unchecked Sendable { private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [ "PaliGemmaProcessor": create( - PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init), + PaliGemmaProcessorConfiguration.self, PaliGemmaProcessor.init), "Qwen2VLProcessor": create( Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init), "Idefics3Processor": create( From e2ef119bc6966f3d498da28f5db4a9b1384fe1d9 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Tue, 4 Mar 2025 19:05:59 +0100 Subject: [PATCH 03/17] Add Qwen 2.5 VL --- Applications/VLMEval/ContentView.swift | 2 +- Libraries/MLXVLM/Models/Qwen25VL.swift | 1087 +++++++++++++++++ Libraries/MLXVLM/Models/Qwen2VL.swift | 193 +-- Libraries/MLXVLM/Models/QwenVL.swift | 188 +++ Libraries/MLXVLM/VLMModelFactory.swift | 13 +- .../xcshareddata/swiftpm/Package.resolved | 14 +- 6 files changed, 1311 insertions(+), 186 deletions(-) create mode 100644 Libraries/MLXVLM/Models/Qwen25VL.swift create mode 100644 Libraries/MLXVLM/Models/QwenVL.swift diff --git a/Applications/VLMEval/ContentView.swift b/Applications/VLMEval/ContentView.swift index f7d4ad9b..ed70cfe4 100644 --- a/Applications/VLMEval/ContentView.swift +++ b/Applications/VLMEval/ContentView.swift @@ -322,7 +322,7 @@ class VLMEvaluator { /// This controls which model loads. `qwen2VL2BInstruct4Bit` is one of the smaller ones, so this will fit on /// more devices. - let modelConfiguration = ModelRegistry.qwen2VL2BInstruct4Bit + let modelConfiguration = ModelRegistry.qwen2_5VL3BInstruct4Bit /// parameters controlling the output let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.6) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift new file mode 100644 index 00000000..6f468579 --- /dev/null +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -0,0 +1,1087 @@ +// Port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen2_5_vl + +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Language + +private enum Language { + + /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors + static private func applyMultimodalRotaryPositionEmbedding( + q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, + positionIds: MLXArray, mropeSection: [Int] + ) -> (MLXArray, MLXArray) { + var cos = cos[positionIds] + var sin = sin[positionIds] + + cos = + concatenated( + // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] + split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] 
+ + sin = + concatenated( + split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + // Apply rotary embedding + let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin) + let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin) + return (qEmbed, kEmbed) + } + + fileprivate class Attention: Module { + + let heads: Int + let kvHeads: Int + let headDim: Int + let scale: Float + let mropeSection: [Int] + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + let dim = args.hiddenSize + self.heads = args.attentionHeads + self.kvHeads = args.kvHeads + self.headDim = dim / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { + // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() + self.mropeSection = sequence(state: (0, array.makeIterator())) { state in + if let v = state.1.next() { + // note the *2 + state.0 += v * 2 + return state.0 + } else { + return nil + } + }.dropLast() + } else { + fatalError("rope_scaling['mrope_section'] must be an array of integers") + } + + self._rotaryEmbedding.wrappedValue = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + + let offset = cache?.offset ?? 
0 + let mask = mask?[0..., 0 ..< keys.dim(-2)] + + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + + if let cache { + (keys, values) = cache.update(keys: keys, values: values) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLDecoderLayer: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } + } + + fileprivate class Qwen25Model: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [Qwen25VLDecoderLayer] + fileprivate let norm: RMSNorm + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + Qwen25VLDecoderLayer(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> MLXArray { + var h: MLXArray + if let inputEmbedding { + h = inputEmbedding + } else if let inputs { + h = embedTokens(inputs) + } else { + fatalError("one of inputs or inputEmbedding must be non-nil") + } + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } + } + + fileprivate class LanguageModel: Module, KVCacheDimensionProvider { + @ModuleInfo var model: Qwen25Model + @ModuleInfo(key: "lm_head") var lmHead: Linear? 
+ + var kvHeads: [Int] + + public init(_ args: Qwen25VLConfiguration.TextConfiguration) { + self.model = Qwen25Model(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> LMOutput { + var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding) + if let lmHead { + out = lmHead(out) + } else { + out = model.embedTokens.asLinear(out) + } + return LMOutput(logits: out) + } + } +} + +// MARK: - Vision + +private enum Vision { + + static fileprivate func applyMultimodalRotaryPositionEmbedding( + _ tensor: MLXArray, freqs: MLXArray + ) -> MLXArray { + var cos = cos(freqs) + var sin = sin(freqs) + + cos = expandedDimensions(cos, axis: 1) + cos = tiled(cos, repetitions: [1, 1, 2]) + cos = expandedDimensions(cos, axis: 0) + + sin = expandedDimensions(sin, axis: 1) + sin = tiled(sin, repetitions: [1, 1, 2]) + sin = expandedDimensions(sin, axis: 0) + + let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin) + return output.asType(tensor.dtype) + } + + fileprivate class PatchMerger: Module, UnaryLayer { + let hiddenSize: Int + @ModuleInfo(key: "ln_q") var layerNormQ: RMSNorm + @ModuleInfo var mlp: (Linear, GELU, Linear) + + init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int) { + self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) + self._layerNormQ.wrappedValue = RMSNorm(dimensions: contextDimensions, eps: 1e-6) + self.mlp = ( + Linear(hiddenSize, hiddenSize), + GELU(), + Linear(hiddenSize, dimensions) + ) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = layerNormQ(x).reshaped(-1, hiddenSize) + x = mlp.0(x) + x = mlp.1(x) + x = mlp.2(x) + return x + } + } + + fileprivate class Attention: Module { + + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "qkv") var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + public init(dims: Int, numHeads: Int) { + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) + self._proj.wrappedValue = Linear(dims, dims) + } + + public func callAsFunction( + _ x: MLXArray, cuSeqlens: MLXArray, rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + let sequenceLength = x.dim(0) + + let qkv = qkv(x) + let s = split(qkv, parts: 3, axis: -1) + var (q, k, v) = (s[0], s[1], s[2]) + + q = q.reshaped(sequenceLength, numHeads, -1) + k = k.reshaped(sequenceLength, numHeads, -1) + v = v.reshaped(sequenceLength, numHeads, -1) + + q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + + // Create attention mask + let attentionMask = full( + [1, sequenceLength, sequenceLength], + values: -Float32.greatestFiniteMagnitude) + + // Update mask for each sequence + for i in 1 ..< cuSeqlens.size { + let start = cuSeqlens[i - 1].item(Int.self) + let end = cuSeqlens[i].item(Int.self) + attentionMask[0..., start ..< end, start ..< end] = MLXArray(0) + } + + q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: q, 
keys: k, values: v, scale: scale, mask: attentionMask + ) + .transposed(0, 2, 1, 3) + .reshaped(sequenceLength, -1) + + return proj(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen25VLVisionBlock: Module { + + @ModuleInfo var norm1: RMSNorm + @ModuleInfo var norm2: RMSNorm + @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo var mlp: MLP + + public init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.norm1 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + self.norm2 = RMSNorm(dimensions: config.hiddenSize, eps: 1e-6) + + self._attention.wrappedValue = Attention( + dims: config.hiddenSize, numHeads: config.numHeads) + + self.mlp = MLP( + dimensions: config.hiddenSize, hiddenDimensions: config.intermediateSize) + } + + func callAsFunction( + _ hiddenStates: MLXArray, cuSeqlens: MLXArray, rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + var hiddenStates = + hiddenStates + + attention( + norm1(hiddenStates), + cuSeqlens: cuSeqlens, + rotaryPositionEmbedding: rotaryPositionEmbedding + ) + hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) + return hiddenStates + } + } + + fileprivate class VisionModel: Module { + + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding + @ModuleInfo(key: "blocks") var blocks: [Qwen25VLVisionBlock] + @ModuleInfo(key: "merger") var patchMerger: PatchMerger + + let spatialMergeSize: Int + let windowSize: Int + let patchSize: Int + let spatialMergeUnit: Int + let fullattBlockIndexes: [Int] + + public init(_ config: Qwen25VLConfiguration.VisionConfiguration) { + self.spatialMergeSize = config.spatialMergeSize + self.windowSize = config.windowSize + self.patchSize = config.patchSize + self.spatialMergeUnit = config.spatialMergeSize * config.spatialMergeSize + self.fullattBlockIndexes = config.fullattBlockIndexes + + self._patchEmbed.wrappedValue = QwenVL.PatchEmbed( + patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize, + inChannels: config.inChannels, + hiddenSize: config.hiddenSize) + + let headDimensions = config.hiddenSize / config.numHeads + self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding( + dimensions: headDimensions / 2, theta: 10_000) + + self._blocks.wrappedValue = (0 ..< config.depth).map { _ in + Qwen25VLVisionBlock(config) + } + self._patchMerger.wrappedValue = PatchMerger( + dimensions: config.outHiddenSize, contextDimensions: config.hiddenSize, + spatialMergeSize: config.spatialMergeSize) + } + + func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray { + var positionIds = [MLXArray]() + + for row in frames { + let (t, h, w) = row.values + + var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1) + hposIds = repeated(hposIds, count: w, axis: 1) + hposIds = + hposIds + .reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + .transposed(0, 2, 1, 3) + .flattened() + + var wposIds = 
expandedDimensions(MLXArray(0 ..< w), axis: 0) + wposIds = repeated(wposIds, count: h, axis: 0) + wposIds = + wposIds + .reshaped( + h / spatialMergeSize, + spatialMergeSize, + w / spatialMergeSize, + spatialMergeSize + ) + .transposed(0, 2, 1, 3) + .flattened() + + let stackedPosIds = stacked([hposIds, wposIds], axis: -1) + positionIds.append(tiled(stackedPosIds, repetitions: [t, 1])) + } + + let indices = concatenated(positionIds, axis: 0) + let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0 + let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[ + indices] + + return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) + } + + func getWindowIndex(_ frames: [THW]) -> (MLXArray, MLXArray) { + var windowIndex = [MLXArray]() + var cuWindowSeqlens = [0] + var windowIndexId = 0 + let vitMergerWindowSize = windowSize / spatialMergeSize / patchSize + + for frame in frames { + let (gridT, gridH, gridW) = frame.values + let llmGridH = gridH / spatialMergeSize + let llmGridW = gridW / spatialMergeSize + + let index = MLXArray(0 ..< (gridT * llmGridH * llmGridW)).reshaped( + gridT, llmGridH, llmGridW) + + let padH = vitMergerWindowSize - llmGridH % vitMergerWindowSize + let padW = vitMergerWindowSize - llmGridW % vitMergerWindowSize + let numWindowsH = (llmGridH + padH) / vitMergerWindowSize + let numWindowsW = (llmGridW + padW) / vitMergerWindowSize + + // Pad the index + let indexPadded = padded( + index, + widths: [[0, 0], [0, padH], [0, padW]], + mode: .constant, + value: MLXArray(-100) + ) + + // Reshape and transpose + let indexReshaped = indexPadded.reshaped( + gridT, + numWindowsH, + vitMergerWindowSize, + numWindowsW, + vitMergerWindowSize + ) + + let indexTransposed = indexReshaped.transposed(0, 1, 3, 2, 4).reshaped( + gridT, + numWindowsH * numWindowsW, + vitMergerWindowSize, + vitMergerWindowSize + ) + + // Calculate sequence lengths + let seqlens = sum(indexTransposed .!= -100, axes: [2, 3]).reshaped(-1) + + // Get valid indices + let indexFlattened = indexTransposed.flattened() + let validIndices = indexFlattened.asArray(Int.self).enumerated() + .filter { $0.element != -100 } + .map { $0.offset } + + let validValues = indexFlattened[MLXArray(validIndices)] + + // Add to window index + windowIndex.append(validValues + windowIndexId) + + // Update cumulative sequence lengths + let cuSeqlensTmp = + cumsum(seqlens, axis: 0) * spatialMergeUnit + cuWindowSeqlens.last! 
+ cuWindowSeqlens.append(contentsOf: cuSeqlensTmp.asArray(Int.self)) + + windowIndexId += gridT * llmGridH * llmGridW + } + + // Concatenate all window indices + let combinedWindowIndex = concatenated(windowIndex, axis: 0) + let cuWindowSeqlensArray = MLXArray(cuWindowSeqlens) + + // Get unique values in cuWindowSeqlens + var seen = Set() + var uniqueIndices = [Int]() + + for (i, value) in cuWindowSeqlens.enumerated() { + if !seen.contains(value) { + seen.insert(value) + uniqueIndices.append(i) + } + } + + let uniqueCuWindowSeqlens = cuWindowSeqlensArray[MLXArray(uniqueIndices)] + + return (combinedWindowIndex, uniqueCuWindowSeqlens) + } + + public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray { + var hiddenStates = patchEmbed(hiddenStates) + let rotaryPosEmb = rotaryPositionEmbedding(frames) + + // Get window indices and sequence lengths + let (windowIndex, cuWindowSeqlens) = getWindowIndex(frames) + + // Reshape and reindex hidden states + let seqLen = hiddenStates.dim(0) + hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) + hiddenStates = hiddenStates[windowIndex, 0..., 0...] + hiddenStates = hiddenStates.reshaped(seqLen, -1) + + // Reshape and reindex rotary position embeddings + var rotaryPosEmbReshaped = rotaryPosEmb.reshaped( + seqLen / spatialMergeUnit, spatialMergeUnit, -1) + rotaryPosEmbReshaped = rotaryPosEmbReshaped[windowIndex, 0..., 0...] + rotaryPosEmbReshaped = rotaryPosEmbReshaped.reshaped(seqLen, -1) + + // Calculate cumulative sequence lengths for full attention + var cuSeqlens = [0] + for frame in frames { + let seqLen = frame.h * frame.w + cuSeqlens.append( + contentsOf: Array(repeating: seqLen, count: frame.t).map { + cuSeqlens.last! + $0 + }) + } + let cuSeqlensArray = MLXArray(cuSeqlens) + + // Process through blocks + for (i, block) in blocks.enumerated() { + // Use full attention for specific blocks, window attention for others + let cuSeqlensNow = + fullattBlockIndexes.contains(i) ? cuSeqlensArray : cuWindowSeqlens + + hiddenStates = block( + hiddenStates, + cuSeqlens: cuSeqlensNow, + rotaryPositionEmbedding: rotaryPosEmbReshaped + ) + } + + // Apply patch merger + hiddenStates = patchMerger(hiddenStates) + + // Reorder back to original sequence + let reverseIndices = argSort(windowIndex, axis: 0) + hiddenStates = hiddenStates[reverseIndices, 0...] + + return hiddenStates + } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4, array.ndim != 5 { + return false + } + + if array.dim(-1) == 3 { + return true + } + + let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3)) + return outChannels >= kH && outChannels >= kW && kH == kW + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitizedWeights = [String: MLXArray]() + + for (k, v) in weights { + if k.contains("position_id") { + // Remove unused position_ids + continue + } else if k.contains("patch_embed.proj.weight") { + // PyTorch conv2d weight tensors have shape: + // [B, out_channels, in_channels, kH, KW] + // MLX conv2d expects the weight be of shape: + // [B, out_channels, kH, KW, in_channels] + if isMLXWeight(v) { + sanitizedWeights[k] = v + } else { + sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1) + } + } else { + sanitizedWeights[k] = v + } + } + + return sanitizedWeights + } + } +} + +// MARK: - Processor + +/// Qwen2.5VL VLM `UserInputProcessor`. +/// +/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``. 
+public class Qwen25VLProcessor: UserInputProcessor { + private let config: Qwen25VLProcessorConfiguration + private let tokenizer: any Tokenizer + + public init(_ config: Qwen25VLProcessorConfiguration, tokenizer: any Tokenizer) { + self.config = config + self.tokenizer = tokenizer + } + + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( + MLXArray, THW + ) { + // first apply the user requested resizing, etc. if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + + let size = images[0].extent.size + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + return MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + var patches = concatenated(processedImages) + let mod = patches.dim(0) % config.temporalPatchSize + if mod != 0 { + let lastPatch = patches[-1, .ellipsis] + let lastPatchRepeated = tiled( + lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) + patches = concatenated([patches, lastPatchRepeated]) + } + let channel = patches.dim(1) + let gridT = patches.dim(0) / self.config.temporalPatchSize + let gridH = resizedHeight / self.config.patchSize + let gridW = resizedWidth / self.config.patchSize + + patches = patches.reshaped( + gridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) + + let flattenedPatches = patches.reshaped( + gridT * gridH * gridW, + channel * config.temporalPatchSize * config.patchSize * config.patchSize + ) + + return (flattenedPatches, .init(gridT, gridH, gridW)) + } + + public func prepare(input: UserInput) async throws -> LMInput { + let messages = input.prompt.asMessages() + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) + + // Text-only input + if input.images.isEmpty, input.videos.isEmpty { + return LMInput(tokens: MLXArray(promptTokens)) + } + + // Process images if any + var processedImage: LMInput.ProcessedImage? + if !input.images.isEmpty { + let imagePixelsAndFrames = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) + processedImage = LMInput.ProcessedImage( + pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } + } + + // Process videos if any + var processedVideo: LMInput.ProcessedVideo? + if !input.videos.isEmpty { + var videosAsImageSequences = [[CIImage]]() + for video in input.videos { + if let imageSequence = try? 
await MediaProcessing.asCIImageSequence( + video.asAVAsset(), samplesPerSecond: 2) + { + videosAsImageSequences.append(imageSequence) + } + } + let videoPixelsAndFrames = try videosAsImageSequences.map { + try preprocess(images: $0, processing: input.processing) + } + let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) + processedVideo = LMInput.ProcessedVideo( + pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } + } + + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + return LMInput( + text: .init(tokens: promptArray, mask: mask), + image: processedImage, + video: processedVideo) + } +} + +// MARK: - Model + +/// Qwen2.5VL VLM +/// +/// This is typically created by ``VLMModelFactory``. +public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { + + @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel + @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel + + public let config: Qwen25VLConfiguration + + public var vocabularySize: Int { config.baseConfiguration.vocabularySize } + public var kvHeads: [Int] { languageModel.kvHeads } + + public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers { + languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ config: Qwen25VLConfiguration) { + self.config = config + self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration) + self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) + } + + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?) + -> MLXArray + { + guard let pixelValues, let frames else { + return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]) + } + + // Get the input embeddings from the language model + let inputEmbeds = languageModel.model.embedTokens(inputIds) + + // Get the ouptut hidden states from the vision model + var hiddenStates = self.visionModel(pixelValues, frames: frames) + + if hiddenStates.ndim == 2 { + hiddenStates = hiddenStates[.newAxis, 0..., 0...] + } + + // Insert special image tokens in the input_ids + return QwenVL.mergeInputIdsWithImageFeatures( + inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates, + imageTokenId: config.baseConfiguration.imageTokenId, + videoTokenId: config.baseConfiguration.videoTokenId) + } + + public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws + -> PrepareResult + { + let dtype = visionModel.patchEmbed.proj.weight.dtype + + // Process both images and videos together + var allPixels: MLXArray? 
+ var allFrames: [THW] = [] + + if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames { + allPixels = imagePixels.asType(dtype) + allFrames.append(contentsOf: imageFrames) + } + + if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames { + if allPixels == nil { + allPixels = videoPixels.asType(dtype) + } else { + allPixels = concatenated([allPixels!, videoPixels.asType(dtype)]) + } + allFrames.append(contentsOf: videoFrames) + } + + let inputEmbeddings = self.inputEmbeddings( + inputIds: input.text.tokens, pixelValues: allPixels, + frames: allFrames.isEmpty ? nil : allFrames) + + let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) + + return .logits(result) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray { + languageModel(inputs, cache: cache).logits + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + visionModel.sanitize( + weights: + Dictionary( + uniqueKeysWithValues: weights.map { key, value in + var key = key + if !key.contains("vision_tower") { + key = key.replacingOccurrences(of: "visual", with: "vision_tower") + } + if !key.contains("language_model") { + key = key.replacingOccurrences( + of: "model", with: "language_model.model") + key = key.replacingOccurrences( + of: "lm_head", with: "language_model.lm_head") + } + + return (key, value) + }) + ) + } +} + +// MARK: - Configuration + +/// Configuration for ``Qwen25VL`` +public struct Qwen25VLConfiguration: Codable, Sendable { + + public struct TextConfiguration: Codable, Sendable { + public let modelType: String + public let hiddenSize: Int + public let hiddenLayers: Int + public let intermediateSize: Int + public let attentionHeads: Int + private let _rmsNormEps: Float? + public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 } + public let vocabularySize: Int + public let kvHeads: Int + private let _maxPositionEmbeddings: Int? + public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 128000 } + private let _ropeTheta: Float? + public var ropeTheta: Float { _ropeTheta ?? 1_000_000 } + private let _ropeTraditional: Bool? + public var ropeTraditional: Bool { _ropeTraditional ?? false } + public let ropeScaling: [String: StringOrNumber]? + private let _tieWordEmbeddings: Bool? + public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? true } + private let _slidingWindow: Int? + public var slidingWindow: Int { _slidingWindow ?? 32768 } + private let _useSlidingWindow: Bool? + public var useSlidingWindow: Bool { _useSlidingWindow ?? false } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case _rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case _maxPositionEmbeddings = "max_position_embeddings" + case _ropeTheta = "rope_theta" + case _ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case _tieWordEmbeddings = "tie_word_embeddings" + case _slidingWindow = "sliding_window" + case _useSlidingWindow = "use_sliding_window" + } + } + + public struct VisionConfiguration: Codable, Sendable { + public let depth: Int + public let hiddenSize: Int + public let intermediateSize: Int + public let outHiddenSize: Int + public let numHeads: Int + public let patchSize: Int + private let _inChans: Int? 
+ public var inChannels: Int { _inChans ?? 3 } + private let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + public let spatialPatchSize: Int + public let spatialMergeSize: Int + public let temporalPatchSize: Int + public let windowSize: Int + public let fullattBlockIndexes: [Int] + public let tokensPerSecond: Int + private let _skipVision: Bool? + public var skipVision: Bool { _skipVision ?? false } + private let _hiddenAct: String? + public var hiddenAct: String { _hiddenAct ?? "silu" } + + enum CodingKeys: String, CodingKey { + case depth + case hiddenSize = "hidden_size" + case intermediateSize = "intermediate_size" + case outHiddenSize = "out_hidden_size" + case numHeads = "num_heads" + case patchSize = "patch_size" + case _inChans = "in_chans" + case _layerNormEps = "layer_norm_eps" // Added this line + case spatialPatchSize = "spatial_patch_size" + case spatialMergeSize = "spatial_merge_size" + case temporalPatchSize = "temporal_patch_size" + case windowSize = "window_size" + case fullattBlockIndexes = "fullatt_block_indexes" + case tokensPerSecond = "tokens_per_second" + case _skipVision = "skip_vision" + case _hiddenAct = "hidden_act" + } + } + + public struct BaseConfiguration: Codable, Sendable { + public let modelType: String + public let vocabularySize: Int + public let imageTokenId: Int + public let videoTokenId: Int + public let visionStartTokenId: Int + public let visionEndTokenId: Int + public let visionTokenId: Int + public let hiddenSize: Int + public let numAttentionHeads: Int + public let numHiddenLayers: Int + public let intermediateSize: Int + public let numKeyValueHeads: Int + public let slidingWindow: Int + public let useSlidingWindow: Bool + public let maxWindowLayers: Int + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabularySize = "vocab_size" + case imageTokenId = "image_token_id" + case videoTokenId = "video_token_id" + case visionStartTokenId = "vision_start_token_id" + case visionEndTokenId = "vision_end_token_id" + case visionTokenId = "vision_token_id" + case hiddenSize = "hidden_size" + case numAttentionHeads = "num_attention_heads" + case numHiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case numKeyValueHeads = "num_key_value_heads" + case slidingWindow = "sliding_window" + case useSlidingWindow = "use_sliding_window" + case maxWindowLayers = "max_window_layers" + } + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: VisionConfiguration + public let baseConfiguration: BaseConfiguration + + enum CodingKeys: String, CodingKey { + case visionConfiguration = "vision_config" + } + + public init(from decoder: any Swift.Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + // this is a sub-dictionary + self.visionConfiguration = try container.decode( + VisionConfiguration.self, forKey: .visionConfiguration) + + // these are overlaid in the top level + self.textConfiguration = try TextConfiguration(from: decoder) + self.baseConfiguration = try BaseConfiguration(from: decoder) + } +} + +/// Configuration for ``Qwen25VLProcessor`` +public struct Qwen25VLProcessorConfiguration: Codable, Sendable { + public struct Size: Codable, Sendable { + public let maxPixels: Int + public let minPixels: Int + + enum CodingKeys: String, CodingKey { + case maxPixels = "max_pixels" + case minPixels = "min_pixels" + } + } + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let 
minPixels: Int + public let maxPixels: Int + public let mergeSize: Int + public let patchSize: Int + public let temporalPatchSize: Int + public let imageProcessorType: String + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + public var size: Size { + Size(maxPixels: maxPixels, minPixels: minPixels) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case minPixels = "min_pixels" + case maxPixels = "max_pixels" + case mergeSize = "merge_size" + case patchSize = "patch_size" + case temporalPatchSize = "temporal_patch_size" + case imageProcessorType = "image_processor_type" + } +} diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index f71e2352..7c4c9ac9 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -11,16 +11,6 @@ import MLXLMCommon import MLXNN import Tokenizers -// MARK: - Common - -/// Rotates half the hidden dims of the input -private func rotateHalf(_ x: MLXArray) -> MLXArray { - let index = x.dim(-1) / 2 - let x1 = x[.ellipsis, 0 ..< index] - let x2 = x[.ellipsis, index...] - return concatenated([-x2, x1], axis: -1) -} - // MARK: - Language private enum Language { @@ -47,8 +37,8 @@ private enum Language { )[0..., .newAxis, 0..., 0...] // Apply rotary embedding - let qEmbed = (q * cos) + (rotateHalf(q) * sin) - let kEmbed = (k * cos) + (rotateHalf(k) * sin) + let qEmbed = (q * cos) + (QwenVL.rotateHalf(q) * sin) + let kEmbed = (k * cos) + (QwenVL.rotateHalf(k) * sin) return (qEmbed, kEmbed) } @@ -267,64 +257,10 @@ private enum Vision { sin = tiled(sin, repetitions: [1, 1, 2]) sin = expandedDimensions(sin, axis: 0) - let output = (tensor * cos) + (rotateHalf(tensor) * sin) + let output = (tensor * cos) + (QwenVL.rotateHalf(tensor) * sin) return output.asType(tensor.dtype) } - fileprivate class VisionRotaryEmbedding { - let dimensions: Int - let theta: Float - let inverseFreq: MLXArray - - init(dimensions: Int, theta: Float) { - self.dimensions = dimensions - self.theta = theta - let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions - self.inverseFreq = 1.0 / pow(theta, p) - } - - func callAsFunction(sequenceLength: Int) -> MLXArray { - let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) - let freqs = outer(seq, inverseFreq) - return freqs - } - } - - fileprivate class PatchEmbed: Module, UnaryLayer { - @ModuleInfo var proj: Conv3d - - let patchSize: Int - let temporalPatchSize: Int - let inChannels: Int - let embedDimensions: Int - - init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int) { - self.patchSize = patchSize - self.temporalPatchSize = temporalPatchSize - self.inChannels = inChannels - self.embedDimensions = embedDimensions - - let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) - self._proj.wrappedValue = Conv3d( - inputChannels: inChannels, - outputChannels: embedDimensions, - kernelSize: kernelSize, - stride: kernelSize, - bias: false - ) - } - - func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { - var hiddenStates = hiddenStates.reshaped( - -1, inChannels, temporalPatchSize, patchSize, patchSize - ).movedAxis(source: 1, destination: 4) - - hiddenStates = proj(hiddenStates) - hiddenStates = hiddenStates.reshaped(-1, embedDimensions) - return hiddenStates - } - 
} - fileprivate class PatchMerger: Module, UnaryLayer { let hiddenSize: Int @ModuleInfo(key: "ln_q") var layerNormQ: LayerNorm @@ -451,8 +387,8 @@ private enum Vision { fileprivate class VisionModel: Module { - @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed - @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "patch_embed") var patchEmbed: QwenVL.PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: QwenVL.VisionRotaryEmbedding @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock] @ModuleInfo(key: "merger") var patchMerger: PatchMerger @@ -461,14 +397,14 @@ private enum Vision { public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { self.spatialMergeSize = config.spatialMergeSize - self._patchEmbed.wrappedValue = PatchEmbed( + self._patchEmbed.wrappedValue = QwenVL.PatchEmbed( patchSize: config.patchSize, temporalPatchSize: config.temporalPatchSize, inChannels: config.inChannels, embedDimensions: config.embedDimensions) let headDimensions = config.embedDimensions / config.numHeads - self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding( + self._rotaryPositionEmbedding.wrappedValue = QwenVL.VisionRotaryEmbedding( dimensions: headDimensions / 2, theta: 10_000) self._blocks.wrappedValue = (0 ..< config.depth).map { _ in @@ -592,38 +528,6 @@ public class Qwen2VLProcessor: UserInputProcessor { self.tokenizer = tokenizer } - // image_processing_qwen2_vl.smart_resize - private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) - var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor - wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - public func preprocess(images: [CIImage], processing: UserInput.Processing?) 
throws -> ( MLXArray, THW ) { @@ -633,7 +537,7 @@ public class Qwen2VLProcessor: UserInputProcessor { // image_processing_qwen2_vl._preprocess let size = images[0].extent.size - let (resizedHeight, resizedWidth) = try targetSize( + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), factor: config.patchSize * config.mergeSize, minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) @@ -708,8 +612,9 @@ public class Qwen2VLProcessor: UserInputProcessor { processedImage = LMInput.ProcessedImage( pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) if let imageFrames = processedImage?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) } } @@ -731,8 +636,9 @@ public class Qwen2VLProcessor: UserInputProcessor { processedVideo = LMInput.ProcessedVideo( pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) if let videoFrames = processedVideo?.frames { - promptTokens = try replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) } } @@ -743,42 +649,6 @@ public class Qwen2VLProcessor: UserInputProcessor { image: processedImage, video: processedVideo) } - - func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) - throws -> [Int] - { - // Replace single padding token with correct number for each image or video frame - let placeholderTokens = try tokenizer.encode( - text: "<|vision_start|>\(paddingToken)<|vision_end|>") - let placeholderRanges = promptTokens.ranges(of: placeholderTokens) - guard placeholderRanges.count == frames.count else { - throw VLMError.processing( - "Number of placeholder tokens does not match number of frames") - } - let mergeLength = config.mergeSize * config.mergeSize - let replacementSequences = try frames.map { frame in - let paddingCount = frame.product / mergeLength - return try tokenizer.encode( - text: - "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" - ) - } - // Build the final array - var result: [Int] = [] - var currentIndex = promptTokens.startIndex - for (range, replacement) in zip(placeholderRanges, replacementSequences) { - // Add tokens before the placeholder - result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) - // Add replacement sequence - result.append(contentsOf: replacement) - currentIndex = range.upperBound - } - // Add any remaining tokens after the last replacement - if currentIndex < promptTokens.endIndex { - result.append(contentsOf: promptTokens[currentIndex...]) - } - return result - } } // MARK: - Model @@ -824,37 +694,10 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { } // Insert special image tokens in the input_ids - return mergeInputIdsWithImageFeatures( - inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates) - } - - private func mergeInputIdsWithImageFeatures( - inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray - ) -> MLXArray { - let imageTokenIndex = config.baseConfiguration.imageTokenId - let videoTokenIndex = 
config.baseConfiguration.videoTokenId - - var imageIndices = [Int]() - for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == imageTokenIndex || v == videoTokenIndex { - imageIndices.append(i) - } - } - - // Make sure shapes match before assignment - var result = inputEmbeds - if result.ndim == 2 { - result = result[.newAxis, 0..., 0...] - } - - if imageFeatures.ndim == 2 { - let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] - result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures - } else { - result[0..., MLXArray(imageIndices), 0...] = imageFeatures - } - - return result + return QwenVL.mergeInputIdsWithImageFeatures( + inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates, + imageTokenId: config.baseConfiguration.imageTokenId, + videoTokenId: config.baseConfiguration.videoTokenId) } public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift new file mode 100644 index 00000000..b76ff7eb --- /dev/null +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -0,0 +1,188 @@ +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Common Utilities for Qwen 2 VL and Qwen 2.5 VL + +public struct QwenVL { + /// Rotates half the hidden dims of the input + static func rotateHalf(_ x: MLXArray) -> MLXArray { + let index = x.dim(-1) / 2 + let x1 = x[.ellipsis, 0 ..< index] + let x2 = x[.ellipsis, index...] + return concatenated([-x2, x1], axis: -1) + } + + static func mergeInputIdsWithImageFeatures( + inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray, + imageTokenId: Int, videoTokenId: Int + ) -> MLXArray { + var imageIndices = [Int]() + for (i, v) in inputIds.asArray(Int.self).enumerated() { + if v == imageTokenId || v == videoTokenId { + imageIndices.append(i) + } + } + + // Make sure shapes match before assignment + var result = inputEmbeds + if result.ndim == 2 { + result = result[.newAxis, 0..., 0...] + } + + if imageFeatures.ndim == 2 { + let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] + result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures + } else { + result[0..., MLXArray(imageIndices), 0...] 
= imageFeatures + } + + return result + } + + public class VisionRotaryEmbedding { + let dimensions: Int + let theta: Float + let inverseFreq: MLXArray + + init(dimensions: Int, theta: Float) { + self.dimensions = dimensions + self.theta = theta + let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions + self.inverseFreq = 1.0 / pow(theta, p) + } + + func callAsFunction(sequenceLength: Int) -> MLXArray { + let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) + let freqs = outer(seq, inverseFreq) + return freqs + } + } + + public class PatchEmbed: Module, UnaryLayer { + @ModuleInfo var proj: Conv3d + + let patchSize: Int + let temporalPatchSize: Int + let inChannels: Int + let outputDimensions: Int + + // For Qwen 2 VL + convenience init( + patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int + ) { + self.init( + patchSize: patchSize, temporalPatchSize: temporalPatchSize, + inChannels: inChannels, outputDimensions: embedDimensions) + } + + // For Qwen 2.5 VL + convenience init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, hiddenSize: Int) { + self.init( + patchSize: patchSize, temporalPatchSize: temporalPatchSize, + inChannels: inChannels, outputDimensions: hiddenSize) + } + + // Common initializer + init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, outputDimensions: Int) { + self.patchSize = patchSize + self.temporalPatchSize = temporalPatchSize + self.inChannels = inChannels + self.outputDimensions = outputDimensions + + let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) + self._proj.wrappedValue = Conv3d( + inputChannels: inChannels, + outputChannels: outputDimensions, + kernelSize: kernelSize, + stride: kernelSize, + bias: false + ) + } + + public func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { + var hiddenStates = hiddenStates.reshaped( + -1, inChannels, temporalPatchSize, patchSize, patchSize + ).movedAxis(source: 1, destination: 4) + + hiddenStates = proj(hiddenStates) + hiddenStates = hiddenStates.reshaped(-1, outputDimensions) + return hiddenStates + } + } + + // image_processing_qwen2_vl.smart_resize + static func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) + throws + -> (Int, Int) + { + if height < factor { + throw VLMError.imageProcessingFailure( + "height: \(height) must be larger than factor: \(factor)") + } + if width < factor { + throw VLMError.imageProcessingFailure( + "width: \(width) must be larger than factor: \(factor)") + } + if max(height, width) / min(height, width) > 200 { + throw VLMError.imageProcessingFailure( + "absolute aspect ratio must be smaller than 200: \(width)x\(height)") + } + + var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) + var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) + + if hBar * wBar > maxPixels { + let beta = sqrt(Float(height * width) / Float(maxPixels)) + hBar = Int(floor(Float(height) / beta / Float(factor))) * factor + wBar = Int(floor(Float(width) / beta / Float(factor))) * factor + } else if hBar * wBar < minPixels { + let beta = sqrt(Float(minPixels) / Float(height * width)) + hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor + wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor + } + return (hBar, wBar) + } + + static func replacePaddingTokens( + in promptTokens: [Int], frames: [THW], paddingToken: String, mergeSize: Int, + tokenizer: any Tokenizer + ) throws -> [Int] { + // Replace single padding 
token with correct number for each image or video frame + let placeholderTokens = try tokenizer.encode( + text: "<|vision_start|>\(paddingToken)<|vision_end|>") + let placeholderRanges = promptTokens.ranges(of: placeholderTokens) + guard placeholderRanges.count == frames.count else { + throw VLMError.processing( + "Number of placeholder tokens does not match number of frames") + } + let mergeLength = mergeSize * mergeSize + let replacementSequences = try frames.map { frame in + let paddingCount = frame.product / mergeLength + return try tokenizer.encode( + text: + "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" + ) + } + // Build the final array + var result: [Int] = [] + var currentIndex = promptTokens.startIndex + for (range, replacement) in zip(placeholderRanges, replacementSequences) { + // Add tokens before the placeholder + result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) + // Add replacement sequence + result.append(contentsOf: replacement) + currentIndex = range.upperBound + } + // Add any remaining tokens after the last replacement + if currentIndex < promptTokens.endIndex { + result.append(contentsOf: promptTokens[currentIndex...]) + } + return result + } +} diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 22bcac69..f05d185e 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -60,6 +60,7 @@ public class ModelTypeRegistry: @unchecked Sendable { private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), + "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init), "idefics3": create(Idefics3Configuration.self, Idefics3.init), ] @@ -85,7 +86,6 @@ public class ModelTypeRegistry: @unchecked Sendable { } return try creator(configuration) } - } public class ProcessorTypeRegistry: @unchecked Sendable { @@ -101,6 +101,8 @@ public class ProcessorTypeRegistry: @unchecked Sendable { PaliGemmaProcessorConfiguration.self, PaliGemmaProcessor.init), "Qwen2VLProcessor": create( Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init), + "Qwen2_5_VLProcessor": create( + Qwen25VLProcessorConfiguration.self, Qwen25VLProcessor.init), "Idefics3Processor": create( Idefics3ProcessorConfiguration.self, Idefics3Processor.init), ] @@ -130,7 +132,6 @@ public class ProcessorTypeRegistry: @unchecked Sendable { } return try creator(configuration, tokenizer) } - } /// Registry of models and any overrides that go with them, e.g. prompt augmentation. 
@@ -157,15 +158,21 @@ public class ModelRegistry: @unchecked Sendable { defaultPrompt: "Describe the image in English" ) + static public let qwen2_5VL3BInstruct4Bit = ModelConfiguration( + id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit", + defaultPrompt: "Describe the image in English" + ) + static public let smolvlminstruct4bit = ModelConfiguration( id: "mlx-community/SmolVLM-Instruct-4bit", defaultPrompt: "Describe the image in English" ) - static private func all() -> [ModelConfiguration] { + static public func all() -> [ModelConfiguration] { [ paligemma3bMix448_8bit, qwen2VL2BInstruct4Bit, + qwen2_5VL3BInstruct4Bit, ] } diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 3ed84667..51da6faa 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c", + "originHash" : "327a4376ec20e25f941929e0bd2eefea67914f3c98414e5489f49c7e49eab7ab", "pins" : [ { "identity" : "gzipswift", @@ -24,8 +24,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", - "version" : "0.21.2" + "revision" : "b990c58153af70eb0914bca7dd74401d341fa9ae", + "version" : "0.21.3" } }, { @@ -87,8 +87,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-numerics", "state" : { - "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", - "version" : "1.0.2" + "revision" : "e0ec0f5f3af6f3e4d5e7a19d2af26b481acb6ba8", + "version" : "1.0.3" } }, { @@ -96,8 +96,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "55710ddfb1ae804b4b7ce973be75cf2e41272185", - "version" : "0.1.17" + "revision" : "be855fac725dbae27264e47a3eb535cc422a4ba8", + "version" : "0.1.18" } } ], From 788fffaea7662c9392fb57531253fd8b1d77e804 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Thu, 6 Mar 2025 08:06:41 +0100 Subject: [PATCH 04/17] Fix media downsampling --- Libraries/MLXVLM/MediaProcessing.swift | 28 +++++++++++++------------ Libraries/MLXVLM/Models/QwenVL.swift | 29 +++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index 80bbbf26..63da46f0 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -15,7 +15,7 @@ private let context = CIContext() /// var image: CIImage /// image = MediaProcessing.inSRGBToneCurveSpace(image) /// -/// // apply user instructions +/// // Apply user instructions /// image = MediaProcessing.apply(image, processing: processing) /// /// image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize) @@ -59,6 +59,10 @@ public enum MediaProcessing { } /// Resample the image using bicubic interpolation. 
+ /// - Parameters: + /// - image: The image to resample + /// - size: The target size + /// - Returns: The resampled image static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { let filter = CIFilter.bicubicScaleTransform() let extent = image.extent.size @@ -70,19 +74,13 @@ public enum MediaProcessing { let desiredAspectRatio = size.width / size.height filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio) - // that image is now the aspect ratio of the target and the size - // of the shorter dimension - let scale: CGFloat - if extent.width < extent.height { - scale = size.width / extent.width - } else { - scale = size.height / extent.height - } + // Use the same scaling approach regardless of orientation + let scale = min(size.width / extent.width, size.height / extent.height) filter.scale = Float(scale) let rescaled = filter.outputImage! - // the image has a DoD larger than the requested size so crop + // The image has a DoD larger than the requested size, so crop // it to the desired size return rescaled.cropped(to: CGRect(origin: .zero, size: size)) } @@ -94,7 +92,7 @@ public enum MediaProcessing { let filter = CIFilter.colorMatrix() filter.inputImage = image - // this should match + // This should match // https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html // // output[channel] = (input[channel] - mean[channel]) / std[channel] @@ -113,6 +111,10 @@ public enum MediaProcessing { } /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]` + /// - Parameters: + /// - image: The image to convert + /// - colorSpace: Optional color space for rendering + /// - Returns: The MLXArray representation of the image static public func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? 
= nil) -> MLXArray { let size = image.extent.size let w = Int(size.width.rounded()) @@ -135,10 +137,10 @@ public enum MediaProcessing { var array = MLXArray(data, [h, w, 4], type: Float32.self) - // drop 4th channel + // Drop 4th channel array = array[0..., 0..., ..<3] - // convert to 1, C, H, W + // Convert to 1, C, H, W array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2) return array diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift index b76ff7eb..a396893a 100644 --- a/Libraries/MLXVLM/Models/QwenVL.swift +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -123,20 +123,25 @@ public struct QwenVL { { if height < factor { throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") + "Height: \(height) must be larger than factor: \(factor)") } if width < factor { throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") + "Width: \(width) must be larger than factor: \(factor)") } if max(height, width) / min(height, width) > 200 { throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") + "Absolute aspect ratio must be smaller than 200: \(width) × \(height)") } + // Maximum allowed dimension for any single side to prevent buffer overflows + // This is important for portrait/landscape images with extreme aspect ratios + let maxDimension = 224 + var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) + // Start by scaling based on total pixel count if hBar * wBar > maxPixels { let beta = sqrt(Float(height * width) / Float(maxPixels)) hBar = Int(floor(Float(height) / beta / Float(factor))) * factor @@ -146,6 +151,24 @@ public struct QwenVL { hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor } + + // Additionally check if either dimension exceeds the maximum allowed + if hBar > maxDimension { + // Calculate how much we need to scale down height + let scale = Float(maxDimension) / Float(hBar) + // Apply that scale to both dimensions to maintain aspect ratio + hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor + wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor + } + + if wBar > maxDimension { + // Calculate how much we need to scale down width + let scale = Float(maxDimension) / Float(wBar) + // Apply that scale to both dimensions to maintain aspect ratio + hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor + wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor + } + return (hBar, wBar) } From 93565497303d0c75978d5372af6bf9eb2b1138f3 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sat, 8 Mar 2025 10:05:40 +0100 Subject: [PATCH 05/17] More media downsampling fixes --- Libraries/MLXVLM/MediaProcessing.swift | 54 +++++++++---- Libraries/MLXVLM/Models/Qwen25VL.swift | 103 ++++++++++++++++++------- Libraries/MLXVLM/Models/Qwen2VL.swift | 103 ++++++++++++++++++------- Libraries/MLXVLM/Models/QwenVL.swift | 10 +++ Libraries/MLXVLM/VLMModelFactory.swift | 17 +++- 5 files changed, 216 insertions(+), 71 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index 63da46f0..4f3f94cb 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -63,26 +63,48 @@ public enum MediaProcessing { /// - image: The image to resample 
/// - size: The target size /// - Returns: The resampled image - static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { - let filter = CIFilter.bicubicScaleTransform() - let extent = image.extent.size - - filter.inputImage = image + public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { + // First, create a CIFilter for precise resampling + guard let filter = CIFilter(name: "CILanczosScaleTransform") else { + // Fall back to affine transform if filter isn't available + let scaleX = size.width / image.extent.width + let scaleY = size.height / image.extent.height + let transform = CGAffineTransform(scaleX: scaleX, y: scaleY) + let scaled = image.transformed(by: transform) + + // Force exact dimensions by cropping + return scaled.cropped(to: CGRect(origin: .zero, size: size)) + } - // set the aspect ratio to match the aspect ratio of the target - let inputAspectRatio = extent.width / extent.height - let desiredAspectRatio = size.width / size.height - filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio) + filter.setValue(image, forKey: kCIInputImageKey) + filter.setValue(size.width / image.extent.width, forKey: kCIInputScaleKey) + filter.setValue(1.0, forKey: kCIInputAspectRatioKey) - // Use the same scaling approach regardless of orientation - let scale = min(size.width / extent.width, size.height / extent.height) - filter.scale = Float(scale) + guard let scaledImage = filter.outputImage else { + // Fall back if filter fails + let scaleX = size.width / image.extent.width + let scaleY = size.height / image.extent.height + let transform = CGAffineTransform(scaleX: scaleX, y: scaleY) + let scaled = image.transformed(by: transform) - let rescaled = filter.outputImage! + return scaled.cropped(to: CGRect(origin: .zero, size: size)) + } - // The image has a DoD larger than the requested size, so crop - // it to the desired size - return rescaled.cropped(to: CGRect(origin: .zero, size: size)) + // Calculate the crop rect to get exactly the requested size + // Scale height separately to match the target height + let heightScale = size.height / scaledImage.extent.height + let finalImage = scaledImage.transformed(by: CGAffineTransform(scaleX: 1.0, y: heightScale)) + + // Create a rect with the exact dimensions we want + let exactRect = CGRect( + x: 0, + y: 0, + width: size.width, + height: size.height + ) + + // Crop to ensure exact dimensions + return finalImage.cropped(to: exactRect) } /// Normalize the image using the given mean and standard deviation parameters. diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 6f468579..9755c29d 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -668,11 +668,10 @@ public class Qwen25VLProcessor: UserInputProcessor { public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( MLXArray, THW ) { - // first apply the user requested resizing, etc. if any + // First apply the user requested resizing, etc. 
if any let images = images.map { MediaProcessing.apply($0, processing: processing) } // image_processing_qwen2_vl._preprocess - let size = images[0].extent.size let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), @@ -680,6 +679,7 @@ public class Qwen25VLProcessor: UserInputProcessor { minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + // Process images let processedImages = try images .map { @@ -696,7 +696,25 @@ public class Qwen25VLProcessor: UserInputProcessor { MediaProcessing.asMLXArray($0) } + // Calculate grid dimensions + let gridT = images.count + let gridH = resizedHeight / config.patchSize + let gridW = resizedWidth / config.patchSize + + // Ensure dimensions are valid + guard + resizedHeight % config.patchSize == 0 && resizedWidth % config.patchSize == 0 + && gridH % config.mergeSize == 0 && gridW % config.mergeSize == 0 + else { + throw VLMError.imageProcessingFailure( + "Image dimensions must be divisible by patch size and merge size") + } + + // Concatenate images and handle temporal patch size var patches = concatenated(processedImages) + let channel = patches.dim(1) + + // Pad to match temporal patch size if needed let mod = patches.dim(0) % config.temporalPatchSize if mod != 0 { let lastPatch = patches[-1, .ellipsis] @@ -704,34 +722,53 @@ public class Qwen25VLProcessor: UserInputProcessor { lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) patches = concatenated([patches, lastPatchRepeated]) } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) + + // Recalculate gridT after padding + let actualGridT = patches.dim(0) / config.temporalPatchSize + + // Calculate expected size for verification + let totalElements = patches.size + let expectedElements = + actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth + + // Try to reshape with careful dimension calculation + do { + patches = patches.reshaped( + actualGridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + } catch { + // If reshape fails, provide detailed error + throw VLMError.imageProcessingFailure( + "Failed to reshape patches: \(error). 
Patches shape: \(patches.shape), " + + "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), " + + "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), " + + "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))" + ) + } + + // Continue with transpose and final reshape patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize + actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize), + channel * config.temporalPatchSize * (config.mergeSize * config.patchSize) + * (config.mergeSize * config.patchSize) ) - return (flattenedPatches, .init(gridT, gridH, gridW)) + return (flattenedPatches, .init(actualGridT, gridH, gridW)) } public func prepare(input: UserInput) async throws -> LMInput { let messages = input.prompt.asMessages() + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) // Text-only input @@ -748,10 +785,16 @@ public class Qwen25VLProcessor: UserInputProcessor { let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) processedImage = LMInput.ProcessedImage( pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { - promptTokens = try QwenVL.replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", - mergeSize: config.mergeSize, tokenizer: tokenizer) + do { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } catch { + print("Error in replacePaddingTokens: \(error)") + throw error + } } } @@ -772,10 +815,16 @@ public class Qwen25VLProcessor: UserInputProcessor { let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) processedVideo = LMInput.ProcessedVideo( pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { - promptTokens = try QwenVL.replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", - mergeSize: config.mergeSize, tokenizer: tokenizer) + do { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } catch { + print("Error in video replacePaddingTokens: \(error)") + throw error + } } } diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 7c4c9ac9..4ed0a279 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -531,11 +531,10 @@ public class Qwen2VLProcessor: UserInputProcessor { public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( MLXArray, THW ) { - // first apply the user requested resizing, etc. if any + // First apply the user requested resizing, etc. 
if any let images = images.map { MediaProcessing.apply($0, processing: processing) } // image_processing_qwen2_vl._preprocess - let size = images[0].extent.size let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), @@ -543,6 +542,7 @@ public class Qwen2VLProcessor: UserInputProcessor { minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + // Process images let processedImages = try images .map { @@ -559,7 +559,25 @@ public class Qwen2VLProcessor: UserInputProcessor { MediaProcessing.asMLXArray($0) } + // Calculate grid dimensions + let gridT = images.count + let gridH = resizedHeight / config.patchSize + let gridW = resizedWidth / config.patchSize + + // Ensure dimensions are valid + guard + resizedHeight % config.patchSize == 0 && resizedWidth % config.patchSize == 0 + && gridH % config.mergeSize == 0 && gridW % config.mergeSize == 0 + else { + throw VLMError.imageProcessingFailure( + "Image dimensions must be divisible by patch size and merge size") + } + + // Concatenate images and handle temporal patch size var patches = concatenated(processedImages) + let channel = patches.dim(1) + + // Pad to match temporal patch size if needed let mod = patches.dim(0) % config.temporalPatchSize if mod != 0 { let lastPatch = patches[-1, .ellipsis] @@ -567,34 +585,53 @@ public class Qwen2VLProcessor: UserInputProcessor { lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) patches = concatenated([patches, lastPatchRepeated]) } - let channel = patches.dim(1) - let gridT = patches.dim(0) / self.config.temporalPatchSize - let gridH = resizedHeight / self.config.patchSize - let gridW = resizedWidth / self.config.patchSize - - patches = patches.reshaped( - gridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) + + // Recalculate gridT after padding + let actualGridT = patches.dim(0) / config.temporalPatchSize + + // Calculate expected size for verification + let totalElements = patches.size + let expectedElements = + actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth + + // Try to reshape with careful dimension calculation + do { + patches = patches.reshaped( + actualGridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + } catch { + // If reshape fails, provide detailed error + throw VLMError.imageProcessingFailure( + "Failed to reshape patches: \(error). 
Patches shape: \(patches.shape), " + + "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), " + + "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), " + + "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))" + ) + } + + // Continue with transpose and final reshape patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) let flattenedPatches = patches.reshaped( - gridT * gridH * gridW, - channel * config.temporalPatchSize * config.patchSize * config.patchSize + actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize), + channel * config.temporalPatchSize * (config.mergeSize * config.patchSize) + * (config.mergeSize * config.patchSize) ) - return (flattenedPatches, .init(gridT, gridH, gridW)) + return (flattenedPatches, .init(actualGridT, gridH, gridW)) } public func prepare(input: UserInput) async throws -> LMInput { let messages = input.prompt.asMessages() + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) // Text-only input @@ -611,10 +648,16 @@ public class Qwen2VLProcessor: UserInputProcessor { let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) processedImage = LMInput.ProcessedImage( pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { - promptTokens = try QwenVL.replacePaddingTokens( - in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", - mergeSize: config.mergeSize, tokenizer: tokenizer) + do { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } catch { + print("Error in replacePaddingTokens: \(error)") + throw error + } } } @@ -635,10 +678,16 @@ public class Qwen2VLProcessor: UserInputProcessor { let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) processedVideo = LMInput.ProcessedVideo( pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { - promptTokens = try QwenVL.replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", - mergeSize: config.mergeSize, tokenizer: tokenizer) + do { + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) + } catch { + print("Error in video replacePaddingTokens: \(error)") + throw error + } } } diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift index a396893a..5c30f0b3 100644 --- a/Libraries/MLXVLM/Models/QwenVL.swift +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -169,6 +169,16 @@ public struct QwenVL { wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor } + // Ensure dimensions are divisible by the factor + hBar = (hBar / factor) * factor + wBar = (wBar / factor) * factor + + // Final sanity check + if hBar <= 0 || wBar <= 0 { + throw VLMError.imageProcessingFailure( + "Invalid target dimensions: \(wBar) × \(hBar)") + } + return (hBar, wBar) } diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index f05d185e..6deac525 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -6,12 +6,27 @@ import MLX import MLXLMCommon import Tokenizers -public enum VLMError: Error { +public enum VLMError: LocalizedError { case imageRequired case 
maskRequired case singleImageAllowed case imageProcessingFailure(String) case processing(String) + + public var errorDescription: String? { + switch self { + case .imageRequired: + return "An image is required for this operation." + case .maskRequired: + return "A mask is required for this operation." + case .singleImageAllowed: + return "Only a single image is allowed for this operation." + case .imageProcessingFailure(let message): + return "Image processing failed: \(message)" + case .processing(let message): + return "Processing error: \(message)" + } + } } public struct BaseProcessorConfiguration: Codable, Sendable { From 6a917feedab020a8252ff810057c6ad0914d9ddb Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Mon, 10 Mar 2025 10:50:23 +0100 Subject: [PATCH 06/17] Fix downsampling --- Libraries/MLXVLM/Models/Qwen25VL.swift | 2 ++ Libraries/MLXVLM/Models/QwenVL.swift | 31 ++++++++------------------ 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 9755c29d..35a3403e 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -679,6 +679,8 @@ public class Qwen25VLProcessor: UserInputProcessor { minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + print("config.size.maxPixels: \(config.size.maxPixels)") + // Process images let processedImages = try images diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift index 5c30f0b3..d3fcef0d 100644 --- a/Libraries/MLXVLM/Models/QwenVL.swift +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -121,6 +121,9 @@ public struct QwenVL { throws -> (Int, Int) { + print("Original dimensions: \(width) × \(height)") + print("Factor: \(factor), minPixels: \(minPixels), maxPixels: \(maxPixels)") + if height < factor { throw VLMError.imageProcessingFailure( "Height: \(height) must be larger than factor: \(factor)") @@ -134,44 +137,28 @@ public struct QwenVL { "Absolute aspect ratio must be smaller than 200: \(width) × \(height)") } - // Maximum allowed dimension for any single side to prevent buffer overflows - // This is important for portrait/landscape images with extreme aspect ratios - let maxDimension = 224 - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) + print("After rounding to factor multiples: \(wBar) × \(hBar)") - // Start by scaling based on total pixel count + // Scale based on total pixel count if hBar * wBar > maxPixels { let beta = sqrt(Float(height * width) / Float(maxPixels)) hBar = Int(floor(Float(height) / beta / Float(factor))) * factor wBar = Int(floor(Float(width) / beta / Float(factor))) * factor + print("After scaling down for maxPixels: \(wBar) × \(hBar)") } else if hBar * wBar < minPixels { let beta = sqrt(Float(minPixels) / Float(height * width)) hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor - } - - // Additionally check if either dimension exceeds the maximum allowed - if hBar > maxDimension { - // Calculate how much we need to scale down height - let scale = Float(maxDimension) / Float(hBar) - // Apply that scale to both dimensions to maintain aspect ratio - hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor - wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor - } - - 
if wBar > maxDimension { - // Calculate how much we need to scale down width - let scale = Float(maxDimension) / Float(wBar) - // Apply that scale to both dimensions to maintain aspect ratio - hBar = Int(round(Float(hBar) * scale / Float(factor))) * factor - wBar = Int(round(Float(wBar) * scale / Float(factor))) * factor + print("After scaling up for minPixels: \(wBar) × \(hBar)") } // Ensure dimensions are divisible by the factor hBar = (hBar / factor) * factor wBar = (wBar / factor) * factor + print("Final dimensions: \(wBar) × \(hBar)") + print("Total pixels: \(wBar * hBar)") // Final sanity check if hBar <= 0 || wBar <= 0 { From 008d804f12f6b7a746c16586f488c968aaab811d Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Mon, 10 Mar 2025 11:06:44 +0100 Subject: [PATCH 07/17] Increase resize size --- Applications/VLMEval/ContentView.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Applications/VLMEval/ContentView.swift b/Applications/VLMEval/ContentView.swift index ed70cfe4..f539c946 100644 --- a/Applications/VLMEval/ContentView.swift +++ b/Applications/VLMEval/ContentView.swift @@ -421,7 +421,7 @@ class VLMEvaluator { ] } var userInput = UserInput(messages: messages, images: images, videos: videos) - userInput.processing.resize = .init(width: 448, height: 448) + userInput.processing.resize = .init(width: 1344, height: 1344) let input = try await context.processor.prepare(input: userInput) return try MLXLMCommon.generate( input: input, From 19e2aa8a1279ff2e03e3440c101095f17a8314a3 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Mon, 10 Mar 2025 14:02:44 +0100 Subject: [PATCH 08/17] Fix merge commit --- Libraries/MLXVLM/VLMModelFactory.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index cc7d9a4c..725c6b20 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -85,6 +85,7 @@ public class ModelTypeRegistry: @unchecked Sendable { [ "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), + "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init), "idefics3": create(Idefics3Configuration.self, Idefics3.init), ] } From aa59c6e738de8a89a5b8cb5bbeb2a8c435a9b8b2 Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 10 Mar 2025 16:52:45 -0700 Subject: [PATCH 09/17] use binary mask --- Libraries/MLXVLM/Models/Qwen25VL.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 35a3403e..0bd00963 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -318,13 +318,13 @@ private enum Vision { // Create attention mask let attentionMask = full( [1, sequenceLength, sequenceLength], - values: -Float32.greatestFiniteMagnitude) + values: true) // Update mask for each sequence for i in 1 ..< cuSeqlens.size { let start = cuSeqlens[i - 1].item(Int.self) let end = cuSeqlens[i].item(Int.self) - attentionMask[0..., start ..< end, start ..< end] = MLXArray(0) + attentionMask[0..., start ..< end, start ..< end] = MLXArray(false) } q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) From 71244993b112f7c0fa23032f049132dfc38cfa6c Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Tue, 11 Mar 2025 21:24:16 +0100 Subject: [PATCH 10/17] Bicubic works --- Libraries/MLXVLM/MediaProcessing.swift | 42 
+++++++++---------------- Libraries/MLXVLM/Models/Idefics3.swift | 2 +- Libraries/MLXVLM/Models/Paligemma.swift | 4 +-- Libraries/MLXVLM/Models/Qwen25VL.swift | 2 +- Libraries/MLXVLM/Models/Qwen2VL.swift | 2 +- 5 files changed, 20 insertions(+), 32 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index 4f3f94cb..ec632916 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -58,43 +58,32 @@ public enum MediaProcessing { min(other.width / size.width, other.height / size.height) } + enum MediaProcessingError: LocalizedError { + case transformFailed + + var errorDescription: String? { + "Failed to transform image" + } + } + /// Resample the image using bicubic interpolation. /// - Parameters: /// - image: The image to resample /// - size: The target size /// - Returns: The resampled image - public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { - // First, create a CIFilter for precise resampling - guard let filter = CIFilter(name: "CILanczosScaleTransform") else { - // Fall back to affine transform if filter isn't available - let scaleX = size.width / image.extent.width - let scaleY = size.height / image.extent.height - let transform = CGAffineTransform(scaleX: scaleX, y: scaleY) - let scaled = image.transformed(by: transform) - - // Force exact dimensions by cropping - return scaled.cropped(to: CGRect(origin: .zero, size: size)) - } - - filter.setValue(image, forKey: kCIInputImageKey) - filter.setValue(size.width / image.extent.width, forKey: kCIInputScaleKey) - filter.setValue(1.0, forKey: kCIInputAspectRatioKey) - + public static func resampleBicubic(_ image: CIImage, to size: CGSize) throws -> CIImage { + // Create a bicubic scale filter + let filter = CIFilter.bicubicScaleTransform() + filter.inputImage = image + filter.scale = Float(size.width / image.extent.width) + filter.aspectRatio = 1.0 guard let scaledImage = filter.outputImage else { - // Fall back if filter fails - let scaleX = size.width / image.extent.width - let scaleY = size.height / image.extent.height - let transform = CGAffineTransform(scaleX: scaleX, y: scaleY) - let scaled = image.transformed(by: transform) - - return scaled.cropped(to: CGRect(origin: .zero, size: size)) + throw MediaProcessingError.transformFailed } - // Calculate the crop rect to get exactly the requested size // Scale height separately to match the target height let heightScale = size.height / scaledImage.extent.height let finalImage = scaledImage.transformed(by: CGAffineTransform(scaleX: 1.0, y: heightScale)) - // Create a rect with the exact dimensions we want let exactRect = CGRect( x: 0, @@ -102,7 +91,6 @@ public enum MediaProcessing { width: size.width, height: size.height ) - // Crop to ensure exact dimensions return finalImage.cropped(to: exactRect) } diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index 9effd20d..2ff2721e 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -837,7 +837,7 @@ public class Idefics3Processor: UserInputProcessor { height: fixedImageSize ) image = MediaProcessing.apply(image, processing: input.processing) - image = MediaProcessing.resampleBicubic(image, to: targetSize) + image = try MediaProcessing.resampleBicubic(image, to: targetSize) image = MediaProcessing.normalize( image, mean: config.imageMeanTuple, diff --git a/Libraries/MLXVLM/Models/Paligemma.swift 
b/Libraries/MLXVLM/Models/Paligemma.swift index 18c94113..7c66861c 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -451,7 +451,7 @@ public class PaliGemmaProcessor: UserInputProcessor { self.tokenizer = tokenizer } - private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray { + private func prepare(image: CIImage, processing: UserInput.Processing?) throws -> MLXArray { // based on image_processing_siglip from transformers var image = image @@ -463,7 +463,7 @@ public class PaliGemmaProcessor: UserInputProcessor { // apply user instructions image = MediaProcessing.apply(image, processing: processing) - image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize) + image = try MediaProcessing.resampleBicubic(image, to: config.size.cgSize) image = MediaProcessing.normalize( image, mean: config.imageMeanTuple, std: config.imageStdTuple) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 0bd00963..1ddbcef4 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -688,7 +688,7 @@ public class Qwen25VLProcessor: UserInputProcessor { MediaProcessing.inSRGBToneCurveSpace($0) } .map { - return MediaProcessing.resampleBicubic($0, to: resizedSize) + return try MediaProcessing.resampleBicubic($0, to: resizedSize) } .map { MediaProcessing.normalize( diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 4ed0a279..511b1aa6 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -549,7 +549,7 @@ public class Qwen2VLProcessor: UserInputProcessor { MediaProcessing.inSRGBToneCurveSpace($0) } .map { - return MediaProcessing.resampleBicubic($0, to: resizedSize) + return try MediaProcessing.resampleBicubic($0, to: resizedSize) } .map { MediaProcessing.normalize( From d60188aafa9ff84c254755105bf94530bd1b28a1 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Tue, 11 Mar 2025 21:25:49 +0100 Subject: [PATCH 11/17] Fix MediaProcessingError --- Libraries/MLXVLM/MediaProcessing.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index ec632916..0726450d 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -62,7 +62,9 @@ public enum MediaProcessing { case transformFailed var errorDescription: String? { - "Failed to transform image" + switch self { + case .transformFailed: "Failed to transform image" + } } } From d37f85ba2bb9069dfd4f789e4bc200ec089c8c74 Mon Sep 17 00:00:00 2001 From: David Koski Date: Tue, 11 Mar 2025 14:26:27 -0700 Subject: [PATCH 12/17] use computed height rather than extent height (that includes partial pixels from the DOD) --- Libraries/MLXVLM/MediaProcessing.swift | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index 0726450d..fbb5af90 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -62,9 +62,9 @@ public enum MediaProcessing { case transformFailed var errorDescription: String? 
{ - switch self { + switch self { case .transformFailed: "Failed to transform image" - } + } } } @@ -75,16 +75,19 @@ public enum MediaProcessing { /// - Returns: The resampled image public static func resampleBicubic(_ image: CIImage, to size: CGSize) throws -> CIImage { // Create a bicubic scale filter + + let scale = size.width / image.extent.width + let filter = CIFilter.bicubicScaleTransform() filter.inputImage = image - filter.scale = Float(size.width / image.extent.width) + filter.scale = Float(scale) filter.aspectRatio = 1.0 guard let scaledImage = filter.outputImage else { throw MediaProcessingError.transformFailed } // Calculate the crop rect to get exactly the requested size // Scale height separately to match the target height - let heightScale = size.height / scaledImage.extent.height + let heightScale = size.height / (image.extent.height * scale) let finalImage = scaledImage.transformed(by: CGAffineTransform(scaleX: 1.0, y: heightScale)) // Create a rect with the exact dimensions we want let exactRect = CGRect( From 20ac074dd6b5db89bfd95e39506484e6b8a02157 Mon Sep 17 00:00:00 2001 From: David Koski Date: Wed, 12 Mar 2025 07:53:14 -0700 Subject: [PATCH 13/17] use aspect ratio to match target size --- Libraries/MLXVLM/MediaProcessing.swift | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index fbb5af90..e5c1c9b8 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -76,19 +76,16 @@ public enum MediaProcessing { public static func resampleBicubic(_ image: CIImage, to size: CGSize) throws -> CIImage { // Create a bicubic scale filter - let scale = size.width / image.extent.width + let yScale = size.height / image.extent.height + let xScale = size.width / image.extent.width let filter = CIFilter.bicubicScaleTransform() filter.inputImage = image - filter.scale = Float(scale) - filter.aspectRatio = 1.0 + filter.scale = Float(yScale) + filter.aspectRatio = Float(xScale / yScale) guard let scaledImage = filter.outputImage else { throw MediaProcessingError.transformFailed } - // Calculate the crop rect to get exactly the requested size - // Scale height separately to match the target height - let heightScale = size.height / (image.extent.height * scale) - let finalImage = scaledImage.transformed(by: CGAffineTransform(scaleX: 1.0, y: heightScale)) // Create a rect with the exact dimensions we want let exactRect = CGRect( x: 0, @@ -97,7 +94,7 @@ public enum MediaProcessing { height: size.height ) // Crop to ensure exact dimensions - return finalImage.cropped(to: exactRect) + return scaledImage.cropped(to: exactRect) } /// Normalize the image using the given mean and standard deviation parameters. 
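Note on the resampling changes above: the last few patches converge on the same idea — drive CIFilter.bicubicScaleTransform with a vertical scale plus an aspectRatio that carries the horizontal scale, then crop the output to the exact target rect, because the filter's domain of definition can include partial pixels beyond the requested size. Below is a minimal, self-contained sketch of that approach for illustration only; it is not part of any patch in this series, the helper name resampleToExactSize is hypothetical, and error handling for a nil outputImage is omitted (the patches at this point throw a MediaProcessingError instead).

import CoreImage
import CoreImage.CIFilterBuiltins

/// Hypothetical helper, for illustration: bicubic-resample `image` so both
/// axes land on `size`, then crop away any partial pixels contributed by the
/// filter's domain of definition.
func resampleToExactSize(_ image: CIImage, to size: CGSize) -> CIImage {
    let yScale = size.height / image.extent.height
    let xScale = size.width / image.extent.width

    let filter = CIFilter.bicubicScaleTransform()
    filter.inputImage = image
    filter.scale = Float(yScale)                // scale applied to the vertical axis
    filter.aspectRatio = Float(xScale / yScale) // horizontal scale relative to the vertical one

    // Crop to the exact requested rect so the output extent matches the
    // computed target size.
    return filter.outputImage!.cropped(to: CGRect(origin: .zero, size: size))
}

Cropping to the exact rect is what pins the output extent to the size computed by QwenVL.targetSize, which is why the separate height-scaling transform from the earlier revision is dropped here.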
From ee978b0056b1e9cea2120a12e73613946281ad3a Mon Sep 17 00:00:00 2001 From: David Koski Date: Wed, 12 Mar 2025 08:57:21 -0700 Subject: [PATCH 14/17] hoist attention mask generation to VisionModel -- avoid recomputing the mask 32 times --- Libraries/MLXVLM/Models/Qwen25VL.swift | 68 ++++++++++++++------------ 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 1ddbcef4..b53ccb08 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -300,7 +300,7 @@ private enum Vision { } public func callAsFunction( - _ x: MLXArray, cuSeqlens: MLXArray, rotaryPositionEmbedding: MLXArray + _ x: MLXArray, attentionMask: MLXArray, rotaryPositionEmbedding: MLXArray ) -> MLXArray { let sequenceLength = x.dim(0) @@ -315,18 +315,6 @@ private enum Vision { q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) - // Create attention mask - let attentionMask = full( - [1, sequenceLength, sequenceLength], - values: true) - - // Update mask for each sequence - for i in 1 ..< cuSeqlens.size { - let start = cuSeqlens[i - 1].item(Int.self) - let end = cuSeqlens[i].item(Int.self) - attentionMask[0..., start ..< end, start ..< end] = MLXArray(false) - } - q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) k = k.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) v = v.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3) @@ -377,13 +365,13 @@ private enum Vision { } func callAsFunction( - _ hiddenStates: MLXArray, cuSeqlens: MLXArray, rotaryPositionEmbedding: MLXArray + _ hiddenStates: MLXArray, attentionMask: MLXArray, rotaryPositionEmbedding: MLXArray ) -> MLXArray { var hiddenStates = hiddenStates + attention( norm1(hiddenStates), - cuSeqlens: cuSeqlens, + attentionMask: attentionMask, rotaryPositionEmbedding: rotaryPositionEmbedding ) hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) @@ -558,6 +546,22 @@ private enum Vision { return (combinedWindowIndex, uniqueCuWindowSeqlens) } + func attentionMask(sequenceLength: Int, cuSeqlens: MLXArray) -> MLXArray { + // Create attention mask + let attentionMask = full( + [1, sequenceLength, sequenceLength], + values: true) + + // Update mask for each sequence + for i in 1 ..< cuSeqlens.size { + let start = cuSeqlens[i - 1].item(Int.self) + let end = cuSeqlens[i].item(Int.self) + attentionMask[0..., start ..< end, start ..< end] = MLXArray(false) + } + + return attentionMask + } + public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray { var hiddenStates = patchEmbed(hiddenStates) let rotaryPosEmb = rotaryPositionEmbedding(frames) @@ -565,19 +569,8 @@ private enum Vision { // Get window indices and sequence lengths let (windowIndex, cuWindowSeqlens) = getWindowIndex(frames) - // Reshape and reindex hidden states + // prepare attention masks let seqLen = hiddenStates.dim(0) - hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) - hiddenStates = hiddenStates[windowIndex, 0..., 0...] - hiddenStates = hiddenStates.reshaped(seqLen, -1) - - // Reshape and reindex rotary position embeddings - var rotaryPosEmbReshaped = rotaryPosEmb.reshaped( - seqLen / spatialMergeUnit, spatialMergeUnit, -1) - rotaryPosEmbReshaped = rotaryPosEmbReshaped[windowIndex, 0..., 0...] 
- rotaryPosEmbReshaped = rotaryPosEmbReshaped.reshaped(seqLen, -1) - - // Calculate cumulative sequence lengths for full attention var cuSeqlens = [0] for frame in frames { let seqLen = frame.h * frame.w @@ -588,15 +581,30 @@ private enum Vision { } let cuSeqlensArray = MLXArray(cuSeqlens) + let fullAttentionMask = attentionMask(sequenceLength: seqLen, cuSeqlens: cuSeqlensArray) + let windowAttentionMask = attentionMask( + sequenceLength: seqLen, cuSeqlens: cuWindowSeqlens) + + // Reshape and reindex hidden states + hiddenStates = hiddenStates.reshaped(seqLen / spatialMergeUnit, spatialMergeUnit, -1) + hiddenStates = hiddenStates[windowIndex, 0..., 0...] + hiddenStates = hiddenStates.reshaped(seqLen, -1) + + // Reshape and reindex rotary position embeddings + var rotaryPosEmbReshaped = rotaryPosEmb.reshaped( + seqLen / spatialMergeUnit, spatialMergeUnit, -1) + rotaryPosEmbReshaped = rotaryPosEmbReshaped[windowIndex, 0..., 0...] + rotaryPosEmbReshaped = rotaryPosEmbReshaped.reshaped(seqLen, -1) + // Process through blocks for (i, block) in blocks.enumerated() { // Use full attention for specific blocks, window attention for others - let cuSeqlensNow = - fullattBlockIndexes.contains(i) ? cuSeqlensArray : cuWindowSeqlens + let attentionMask = + fullattBlockIndexes.contains(i) ? fullAttentionMask : windowAttentionMask hiddenStates = block( hiddenStates, - cuSeqlens: cuSeqlensNow, + attentionMask: attentionMask, rotaryPositionEmbedding: rotaryPosEmbReshaped ) } From 52610723a13f1ac6e90f2daadb96e6949adf978f Mon Sep 17 00:00:00 2001 From: David Koski Date: Wed, 12 Mar 2025 10:14:05 -0700 Subject: [PATCH 15/17] the bool mask didn't work correctly -- use int8 --- Libraries/MLXVLM/Models/Qwen25VL.swift | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index b53ccb08..b2b9f50d 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -550,13 +550,14 @@ private enum Vision { // Create attention mask let attentionMask = full( [1, sequenceLength, sequenceLength], - values: true) + values: Int8(-127)) // Update mask for each sequence - for i in 1 ..< cuSeqlens.size { - let start = cuSeqlens[i - 1].item(Int.self) - let end = cuSeqlens[i].item(Int.self) - attentionMask[0..., start ..< end, start ..< end] = MLXArray(false) + let cuSeqlens = cuSeqlens.asArray(Int.self) + for i in 1 ..< cuSeqlens.count { + let start = cuSeqlens[i - 1] + let end = cuSeqlens[i] + attentionMask[0..., start ..< end, start ..< end] = MLXArray(Int8(0)) } return attentionMask From 97f6ba1a3c395b29ed22bf28680ed8276a703a2a Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 14 Apr 2025 11:10:28 -0700 Subject: [PATCH 16/17] finish merging main -- some minor refactoring to make qwen2.5 match qwen2 in terms of where it is currently factored --- Libraries/MLXVLM/MediaProcessing.swift | 65 +++------- Libraries/MLXVLM/Models/Qwen25VL.swift | 118 +++++------------ Libraries/MLXVLM/Models/Qwen2VL.swift | 120 +++--------------- Libraries/MLXVLM/Models/QwenVL.swift | 47 +++++++ Libraries/MLXVLM/VLMModelFactory.swift | 2 +- .../xcshareddata/swiftpm/Package.resolved | 2 +- 6 files changed, 124 insertions(+), 230 deletions(-) diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift index 8699b628..63c0efb4 100644 --- a/Libraries/MLXVLM/MediaProcessing.swift +++ b/Libraries/MLXVLM/MediaProcessing.swift @@ -16,7 +16,6 @@ public struct ProcessedFrames { let 
totalDuration: CMTime } -// TODO: verify working color space, rendering color space private let context = CIContext() /// Collection of methods for processing media (images, video, etc.). @@ -76,51 +75,29 @@ public enum MediaProcessing { return Float(1 / inputAspectRatio * desiredAspectRatio) } - /// Resample the image using bicubic interpolation. - public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { - let filter = CIFilter.bicubicScaleTransform() - let extent = image.extent.size - - filter.inputImage = image - - // set the aspect ratio to match the aspect ratio of the target - filter.aspectRatio = aspectRatioForResample(image, size: size) - - // that image is now the aspect ratio of the target and the size - // of the shorter dimension - let scale: CGFloat - if extent.width < extent.height { - scale = size.width / extent.width - } else { - scale = size.height / extent.height - } - filter.scale = Float(scale) - - let rescaled = filter.outputImage! - - // the image has a DoD larger than the requested size so crop - // it to the desired size - return rescaled.cropped(to: CGRect(origin: .zero, size: size)) - } - /// Resample the image using Lanczos interpolation. static public func resampleLanczos(_ image: CIImage, to size: CGSize) -> CIImage { - let filter = CIFilter.lanczosScaleTransform() - let extent = image.extent.size + // Create a bicubic scale filter + + let yScale = size.height / image.extent.height + let xScale = size.width / image.extent.width + let filter = CIFilter.lanczosScaleTransform() filter.inputImage = image + filter.scale = Float(yScale) + filter.aspectRatio = Float(xScale / yScale) + let scaledImage = filter.outputImage! - // set the aspect ratio to match the aspect ratio of the target - filter.aspectRatio = aspectRatioForResample(image, size: size) + // Create a rect with the exact dimensions we want + let exactRect = CGRect( + x: 0, + y: 0, + width: size.width, + height: size.height + ) - // that image is now the aspect ratio of the target and the size - // of the shorter dimension - let scale: CGFloat - if extent.width < extent.height { - scale = size.width / extent.width - } else { - scale = size.height / extent.height - } + // Crop to ensure exact dimensions + return scaledImage.cropped(to: exactRect) } /// Resample the image using bicubic interpolation. @@ -128,7 +105,7 @@ public enum MediaProcessing { /// - image: The image to resample /// - size: The target size /// - Returns: The resampled image - public static func resampleBicubic(_ image: CIImage, to size: CGSize) throws -> CIImage { + public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { // Create a bicubic scale filter let yScale = size.height / image.extent.height @@ -138,9 +115,8 @@ public enum MediaProcessing { filter.inputImage = image filter.scale = Float(yScale) filter.aspectRatio = Float(xScale / yScale) - guard let scaledImage = filter.outputImage else { - throw MediaProcessingError.transformFailed - } + let scaledImage = filter.outputImage! 
+ // Create a rect with the exact dimensions we want let exactRect = CGRect( x: 0, @@ -148,6 +124,7 @@ public enum MediaProcessing { width: size.width, height: size.height ) + // Crop to ensure exact dimensions return scaledImage.cropped(to: exactRect) } diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index b2b9f50d..f13a07e3 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -674,6 +674,13 @@ public class Qwen25VLProcessor: UserInputProcessor { self.tokenizer = tokenizer } + func preprocess(image: CIImage, resizedSize: CGSize) -> CIImage { + image + .toSRGB() + .resampled(to: resizedSize, method: .bicubic) + .normalized(mean: config.imageMeanTuple, std: config.imageStdTuple) + } + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( MLXArray, THW ) { @@ -707,74 +714,9 @@ public class Qwen25VLProcessor: UserInputProcessor { MediaProcessing.asMLXArray($0) } - // Calculate grid dimensions - let gridT = images.count - let gridH = resizedHeight / config.patchSize - let gridW = resizedWidth / config.patchSize - - // Ensure dimensions are valid - guard - resizedHeight % config.patchSize == 0 && resizedWidth % config.patchSize == 0 - && gridH % config.mergeSize == 0 && gridW % config.mergeSize == 0 - else { - throw VLMError.imageProcessingFailure( - "Image dimensions must be divisible by patch size and merge size") - } - - // Concatenate images and handle temporal patch size - var patches = concatenated(processedImages) - let channel = patches.dim(1) - - // Pad to match temporal patch size if needed - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - - // Recalculate gridT after padding - let actualGridT = patches.dim(0) / config.temporalPatchSize - - // Calculate expected size for verification - let totalElements = patches.size - let expectedElements = - actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth - - // Try to reshape with careful dimension calculation - do { - patches = patches.reshaped( - actualGridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - } catch { - // If reshape fails, provide detailed error - throw VLMError.imageProcessingFailure( - "Failed to reshape patches: \(error). 
Patches shape: \(patches.shape), " - + "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), " - + "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), " - + "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))" - ) - } - - // Continue with transpose and final reshape - patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) - - let flattenedPatches = patches.reshaped( - actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize), - channel * config.temporalPatchSize * (config.mergeSize * config.patchSize) - * (config.mergeSize * config.patchSize) - ) - - return (flattenedPatches, .init(actualGridT, gridH, gridW)) + return try QwenVL.patchify( + images: processedImages, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) } public func prepare(input: UserInput) async throws -> LMInput { @@ -812,30 +754,40 @@ public class Qwen25VLProcessor: UserInputProcessor { // Process videos if any var processedVideo: LMInput.ProcessedVideo? if !input.videos.isEmpty { - var videosAsImageSequences = [[CIImage]]() + var videosAsImageSequences = [[MLXArray]]() + var resizedSize: CGSize = .zero for video in input.videos { - if let imageSequence = try? await MediaProcessing.asCIImageSequence( - video.asAVAsset(), samplesPerSecond: 2) - { - videosAsImageSequences.append(imageSequence) + let imageSequence = try await MediaProcessing.asProcessedSequence( + video.asAVAsset(), samplesPerSecond: 2 + ) { frame in + // first apply the user requested resizing, etc. if any + let resizedImage = MediaProcessing.apply( + frame.frame, processing: input.processing) + if resizedSize == .zero { + let size = resizedImage.extent.size + let (resizedHeight, resizedWidth) = try QwenVL.targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.minPixels, maxPixels: config.maxPixels) + resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + } + let processedImage = preprocess(image: resizedImage, resizedSize: resizedSize) + return VideoFrame(frame: processedImage, timeStamp: frame.timeStamp) } + videosAsImageSequences.append(imageSequence.frames) } let videoPixelsAndFrames = try videosAsImageSequences.map { - try preprocess(images: $0, processing: input.processing) + try QwenVL.patchify( + images: $0, mergeSize: config.mergeSize, patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize) } let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) processedVideo = LMInput.ProcessedVideo( pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) - if let videoFrames = processedVideo?.frames { - do { - promptTokens = try QwenVL.replacePaddingTokens( - in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", - mergeSize: config.mergeSize, tokenizer: tokenizer) - } catch { - print("Error in video replacePaddingTokens: \(error)") - throw error - } + promptTokens = try QwenVL.replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>", + mergeSize: config.mergeSize, tokenizer: tokenizer) } } diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index d7fc9c04..8b20d576 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -528,38 +528,6 @@ public class Qwen2VLProcessor: UserInputProcessor { self.tokenizer = tokenizer } - // image_processing_qwen2_vl.smart_resize 
- private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) - throws -> (Int, Int) - { - if height < factor { - throw VLMError.imageProcessingFailure( - "height: \(height) must be larger than factor: \(factor)") - } - if width < factor { - throw VLMError.imageProcessingFailure( - "width: \(width) must be larger than factor: \(factor)") - } - if max(height, width) / min(height, width) > 200 { - throw VLMError.imageProcessingFailure( - "absolute aspect ratio must be smaller than 200: \(width)x\(height)") - } - - var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor) - var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor) - - if hBar * wBar > maxPixels { - let beta = sqrt(Float(height * width) / Float(maxPixels)) - hBar = Int(floor(Float(height) / beta / Float(factor))) * factor - wBar = Int(floor(Float(width) / beta / Float(factor))) * factor - } else if hBar * wBar < minPixels { - let beta = sqrt(Float(minPixels) / Float(height * width)) - hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor - wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor - } - return (hBar, wBar) - } - func preprocess(image: CIImage, resizedSize: CGSize) -> CIImage { image .toSRGB() @@ -570,10 +538,11 @@ public class Qwen2VLProcessor: UserInputProcessor { public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( MLXArray, THW ) { - // First apply the user requested resizing, etc. if any + // first apply the user requested resizing, etc. if any let images = images.map { MediaProcessing.apply($0, processing: processing) } // image_processing_qwen2_vl._preprocess + let size = images[0].extent.size let (resizedHeight, resizedWidth) = try QwenVL.targetSize( height: Int(size.height), width: Int(size.width), @@ -581,60 +550,17 @@ public class Qwen2VLProcessor: UserInputProcessor { minPixels: config.minPixels, maxPixels: config.maxPixels) let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) - let mod = patches.dim(0) % config.temporalPatchSize - if mod != 0 { - let lastPatch = patches[-1, .ellipsis] - let lastPatchRepeated = tiled( - lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1]) - patches = concatenated([patches, lastPatchRepeated]) - } - - // Recalculate gridT after padding - let actualGridT = patches.dim(0) / config.temporalPatchSize - - // Calculate expected size for verification - let totalElements = patches.size - let expectedElements = - actualGridT * config.temporalPatchSize * channel * resizedHeight * resizedWidth - - // Try to reshape with careful dimension calculation - do { - patches = patches.reshaped( - actualGridT, - config.temporalPatchSize, - channel, - gridH / config.mergeSize, - config.mergeSize, - config.patchSize, - gridW / config.mergeSize, - config.mergeSize, - config.patchSize - ) - } catch { - // If reshape fails, provide detailed error - throw VLMError.imageProcessingFailure( - "Failed to reshape patches: \(error). 
Patches shape: \(patches.shape), "
-                    + "Target shape: (\(actualGridT), \(config.temporalPatchSize), \(channel), "
-                    + "\(gridH / config.mergeSize), \(config.mergeSize), \(config.patchSize), "
-                    + "\(gridW / config.mergeSize), \(config.mergeSize), \(config.patchSize))"
-            )
+        let processedImages = try images.map { image in
+            preprocess(image: image, resizedSize: resizedSize).asMLXArray()
         }
 
-        // Continue with transpose and final reshape
-        patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
-
-        let flattenedPatches = patches.reshaped(
-            actualGridT * (gridH / config.mergeSize) * (gridW / config.mergeSize),
-            channel * config.temporalPatchSize * (config.mergeSize * config.patchSize)
-                * (config.mergeSize * config.patchSize)
-        )
-
-        return (flattenedPatches, .init(actualGridT, gridH, gridW))
+        return try QwenVL.patchify(
+            images: processedImages, mergeSize: config.mergeSize, patchSize: config.patchSize,
+            temporalPatchSize: config.temporalPatchSize)
     }
 
     public func prepare(input: UserInput) async throws -> LMInput {
         let messages = input.prompt.asMessages()
-
         var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
 
         // Text-only input
@@ -651,16 +577,10 @@ public class Qwen2VLProcessor: UserInputProcessor {
             let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
             processedImage = LMInput.ProcessedImage(
                 pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
-
             if let imageFrames = processedImage?.frames {
-                do {
-                    promptTokens = try QwenVL.replacePaddingTokens(
-                        in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
-                        mergeSize: config.mergeSize, tokenizer: tokenizer)
-                } catch {
-                    print("Error in replacePaddingTokens: \(error)")
-                    throw error
-                }
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
             }
         }
 
@@ -678,7 +598,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
                         frame.frame, processing: input.processing)
                     if resizedSize == .zero {
                         let size = resizedImage.extent.size
-                        let (resizedHeight, resizedWidth) = try targetSize(
+                        let (resizedHeight, resizedWidth) = try QwenVL.targetSize(
                             height: Int(size.height), width: Int(size.width),
                             factor: config.patchSize * config.mergeSize,
                             minPixels: config.minPixels, maxPixels: config.maxPixels)
@@ -689,20 +609,18 @@ public class Qwen2VLProcessor: UserInputProcessor {
                 }
                 videosAsImageSequences.append(imageSequence.frames)
             }
-            let videoPixelsAndFrames = try videosAsImageSequences.map(patchify)
+            let videoPixelsAndFrames = try videosAsImageSequences.map {
+                try QwenVL.patchify(
+                    images: $0, mergeSize: config.mergeSize, patchSize: config.patchSize,
+                    temporalPatchSize: config.temporalPatchSize)
+            }
             let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
             processedVideo = LMInput.ProcessedVideo(
                 pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
-
             if let videoFrames = processedVideo?.frames {
-                do {
-                    promptTokens = try QwenVL.replacePaddingTokens(
-                        in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
-                        mergeSize: config.mergeSize, tokenizer: tokenizer)
-                } catch {
-                    print("Error in video replacePaddingTokens: \(error)")
-                    throw error
-                }
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
             }
         }
 
diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift
index d3fcef0d..8d12eaea 100644
--- a/Libraries/MLXVLM/Models/QwenVL.swift
+++ b/Libraries/MLXVLM/Models/QwenVL.swift
@@ -205,4 +205,51 @@ public struct QwenVL {
         }
         return result
     }
+
+    static func patchify(images: [MLXArray], mergeSize: Int, patchSize: Int, temporalPatchSize: Int)
+        throws -> (
+            MLXArray, THW
+        )
+    {
+        guard let firstImage = images.first else {
+            throw VLMError.imageProcessingFailure("No images in video sequence")
+        }
+        let resizedHeight = firstImage.dim(-2)
+        let resizedWidth = firstImage.dim(-1)
+        var patches = concatenated(images)
+
+        // Pad to match temporal patch size if needed
+        let mod = patches.dim(0) % temporalPatchSize
+        if mod != 0 {
+            let lastPatch = patches[-1, .ellipsis]
+            let lastPatchRepeated = tiled(
+                lastPatch, repetitions: [temporalPatchSize - mod, 1, 1, 1])
+            patches = concatenated([patches, lastPatchRepeated])
+        }
+        let channel = patches.dim(1)
+        let gridT = patches.dim(0) / temporalPatchSize
+        let gridH = resizedHeight / patchSize
+        let gridW = resizedWidth / patchSize
+
+        patches = patches.reshaped(
+            gridT,
+            temporalPatchSize,
+            channel,
+            gridH / mergeSize,
+            mergeSize,
+            patchSize,
+            gridW / mergeSize,
+            mergeSize,
+            patchSize
+        )
+        patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
+
+        let flattenedPatches = patches.reshaped(
+            gridT * gridH * gridW,
+            channel * temporalPatchSize * patchSize * patchSize
+        )
+
+        return (flattenedPatches, .init(gridT, gridH, gridW))
+    }
+
 }
diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift
index e7a036b1..0079fc4f 100644
--- a/Libraries/MLXVLM/VLMModelFactory.swift
+++ b/Libraries/MLXVLM/VLMModelFactory.swift
@@ -144,7 +144,7 @@ public class VLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         id: "mlx-community/SmolVLM-Instruct-4bit",
         defaultPrompt: "Describe the image in English"
     )
-    
+
     static public let smolvlm = ModelConfiguration(
         id: "HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx",
         defaultPrompt:
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 90cab490..9553e9a3 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,5 +1,5 @@
 {
-  "originHash" : "369f2014f0f4b1785f2b2642d3b4a3cbd3220a38b18d03ac9d74965949a0f0ba",
+  "originHash" : "0777c427cd29bb45ee52257882d29c3c2063039870a79b9b91a32154eb35f7b5",
   "pins" : [
     {
       "identity" : "gzipswift",

From 184d0b9af912c6f3157a2e886af14378d5d1bac9 Mon Sep 17 00:00:00 2001
From: David Koski
Date: Mon, 14 Apr 2025 11:17:34 -0700
Subject: [PATCH 17/17] remove/convert debug printing in the model

---
 Libraries/MLXVLM/Models/Qwen25VL.swift | 13 +++----------
 Libraries/MLXVLM/Models/QwenVL.swift   | 18 +++++++++++-------
 2 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift
index f13a07e3..4ddce506 100644
--- a/Libraries/MLXVLM/Models/Qwen25VL.swift
+++ b/Libraries/MLXVLM/Models/Qwen25VL.swift
@@ -695,8 +695,6 @@ public class Qwen25VLProcessor: UserInputProcessor {
             minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
         let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
 
-        print("config.size.maxPixels: \(config.size.maxPixels)")
-
         // Process images
         let processedImages = try images
 
@@ -740,14 +738,9 @@ public class Qwen25VLProcessor: UserInputProcessor {
                 pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
 
             if let imageFrames = processedImage?.frames {
-                do {
-                    promptTokens = try QwenVL.replacePaddingTokens(
-                        in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
-                        mergeSize: config.mergeSize, tokenizer: tokenizer)
-                } catch {
-                    print("Error in replacePaddingTokens: \(error)")
-                    throw error
-                }
+                promptTokens = try QwenVL.replacePaddingTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>",
+                    mergeSize: config.mergeSize, tokenizer: tokenizer)
             }
         }
 
diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift
index 8d12eaea..e1bf168c 100644
--- a/Libraries/MLXVLM/Models/QwenVL.swift
+++ b/Libraries/MLXVLM/Models/QwenVL.swift
@@ -9,6 +9,10 @@ import Tokenizers
 
 // MARK: - Common Utilities for Qwen 2 VL and Qwen 2.5 VL
 
+private func debug(_ message: @autoclosure () -> String) {
+    // print(message())
+}
+
 public struct QwenVL {
     /// Rotates half the hidden dims of the input
     static func rotateHalf(_ x: MLXArray) -> MLXArray {
@@ -121,8 +125,8 @@ public struct QwenVL {
         throws -> (Int, Int)
     {
-        print("Original dimensions: \(width) × \(height)")
-        print("Factor: \(factor), minPixels: \(minPixels), maxPixels: \(maxPixels)")
+        debug("Original dimensions: \(width) × \(height)")
+        debug("Factor: \(factor), minPixels: \(minPixels), maxPixels: \(maxPixels)")
 
         if height < factor {
             throw VLMError.imageProcessingFailure(
@@ -139,26 +143,26 @@ public struct QwenVL {
         var hBar = max(factor, Int(round(Float(height) / Float(factor))) * factor)
         var wBar = max(factor, Int(round(Float(width) / Float(factor))) * factor)
-        print("After rounding to factor multiples: \(wBar) × \(hBar)")
+        debug("After rounding to factor multiples: \(wBar) × \(hBar)")
 
         // Scale based on total pixel count
         if hBar * wBar > maxPixels {
             let beta = sqrt(Float(height * width) / Float(maxPixels))
             hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
             wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
-            print("After scaling down for maxPixels: \(wBar) × \(hBar)")
+            debug("After scaling down for maxPixels: \(wBar) × \(hBar)")
         } else if hBar * wBar < minPixels {
             let beta = sqrt(Float(minPixels) / Float(height * width))
             hBar = Int(ceil(Float(height) * beta / Float(factor))) * factor
             wBar = Int(ceil(Float(width) * beta / Float(factor))) * factor
-            print("After scaling up for minPixels: \(wBar) × \(hBar)")
+            debug("After scaling up for minPixels: \(wBar) × \(hBar)")
         }
 
         // Ensure dimensions are divisible by the factor
         hBar = (hBar / factor) * factor
         wBar = (wBar / factor) * factor
-        print("Final dimensions: \(wBar) × \(hBar)")
-        print("Total pixels: \(wBar * hBar)")
+        debug("Final dimensions: \(wBar) × \(hBar)")
+        debug("Total pixels: \(wBar * hBar)")
 
         // Final sanity check
         if hBar <= 0 || wBar <= 0 {