From dbf2ead1812d49167829e835eb0b4c5c32db7853 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sun, 26 Jan 2025 14:23:50 +0100 Subject: [PATCH 1/6] Update swift-transformers --- .../xcshareddata/swiftpm/Package.resolved | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 8f722eba..5245a71d 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -73,6 +73,15 @@ "version" : "1.1.4" } }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", + "version" : "1.1.4" + } + }, { "identity" : "swift-markdown-ui", "kind" : "remoteSourceControl", From 39800bc15d3b92bfce138efcce3d9b390c87aa07 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sun, 26 Jan 2025 14:24:28 +0100 Subject: [PATCH 2/6] Use chat template for Qwen 2 VL --- Applications/VLMEval/ContentView.swift | 30 ++- Libraries/MLXLLM/LLMModelFactory.swift | 2 +- Libraries/MLXLMCommon/LanguageModel.swift | 14 +- Libraries/MLXLMCommon/UserInput.swift | 11 +- Libraries/MLXVLM/Models/Idefics3.swift | 2 +- Libraries/MLXVLM/Models/Paligemma.swift | 2 +- Libraries/MLXVLM/Models/Qwen2VL.swift | 230 +++++++++++----------- Libraries/MLXVLM/VLMModelFactory.swift | 1 + 8 files changed, 156 insertions(+), 136 deletions(-) diff --git a/Applications/VLMEval/ContentView.swift b/Applications/VLMEval/ContentView.swift index 33b4debe..041fa097 100644 --- a/Applications/VLMEval/ContentView.swift +++ b/Applications/VLMEval/ContentView.swift @@ -383,9 +383,33 @@ class VLMEvaluator { MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000)) let result = try await modelContainer.perform { context in - let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : [] - let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : [] - var userInput = UserInput(prompt: prompt, images: images, videos: videos) + let images: [UserInput.Image] = + if let image { + [UserInput.Image.ciImage(image)] + } else { + [] + } + let videos: [UserInput.Video] = + if let videoURL { + [.url(videoURL)] + } else { + [] + } + var userInput = UserInput( + messages: [ + [ + "role": "user", + "content": [ + ["type": "text", "text": prompt] + ] + + images.map { _ in + ["type": "image"] + } + + videos.map { _ in + ["type": "video"] + }, + ] + ], images: images, videos: videos) userInput.processing.resize = .init(width: 448, height: 448) let input = try await context.processor.prepare(input: userInput) diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index 9ba7efeb..a7cfd123 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -249,7 +249,7 @@ private struct LLMUserInputProcessor: UserInputProcessor { // but that is not public so just fall back to text let prompt = input.prompt .asMessages() - .compactMap { $0["content"] } + .compactMap { $0["content"] as? String } .joined(separator: ". 
") let promptTokens = tokenizer.encode(text: prompt) return LMInput(tokens: MLXArray(promptTokens)) diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index d51b0443..25e900aa 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -69,14 +69,16 @@ public struct LMInput { /// Representation of prepared input image(s). public struct ProcessedImage { + /// Concatenated pixels from one or more images public let pixels: MLXArray - public let imageGridThw: [THW]? + /// Time, height, and width of the images + public let frames: [THW]? public init( - pixels: MLXArray, imageGridThw: [THW]? = nil + pixels: MLXArray, frames: [THW]? = nil ) { self.pixels = pixels - self.imageGridThw = imageGridThw + self.frames = frames } } @@ -85,13 +87,13 @@ public struct LMInput { public struct ProcessedVideo { public let pixels: MLXArray - public let videoGridThw: [THW]? + public let frames: [THW]? public init( - pixels: MLXArray, videoGridThw: [THW]? = nil + pixels: MLXArray, frames: [THW]? = nil ) { self.pixels = pixels - self.videoGridThw = videoGridThw + self.frames = frames } } diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift index 8c45cd01..d2c96d20 100644 --- a/Libraries/MLXLMCommon/UserInput.swift +++ b/Libraries/MLXLMCommon/UserInput.swift @@ -6,18 +6,19 @@ import Foundation import MLX import Tokenizers +public typealias Message = [String: Any] + /// Container for raw user input. /// /// A ``UserInputProcessor`` can convert this to ``LMInput``. /// See also ``ModelContext``. public struct UserInput: Sendable { - /// Representation of a prompt or series of messages (conversation). public enum Prompt: Sendable, CustomStringConvertible { case text(String) - case messages([[String: String]]) + case messages([Message]) - public func asMessages() -> [[String: String]] { + public func asMessages() -> [Message] { switch self { case .text(let text): return [["role": "user", "content": text]] @@ -144,11 +145,13 @@ public struct UserInput: Sendable { } public init( - messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil, + messages: [Message], images: [Image] = [Image](), videos: [Video] = [Video](), + tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil ) { self.prompt = .messages(messages) self.images = images + self.videos = videos self.tools = tools self.additionalContext = additionalContext } diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index 5a73b539..9effd20d 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor { } public func prepare(input: UserInput) throws -> LMInput { - let prompt = input.prompt.asMessages().last?["content"] ?? "" + let prompt = input.prompt.asMessages().last?["content"] as? String ?? "" if input.images.isEmpty { // No image scenario diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift index a103ccb4..76cc89ca 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor { } // this doesn't have a chat template so just use the last message. - var prompt = input.prompt.asMessages().last?["content"] ?? "" + var prompt = input.prompt.asMessages().last?["content"] as? String ?? 
"" // based on transformers/processing_paligemma let count = input.images.count * config.imageSequenceLength diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index cb4c2af0..393db5c4 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -367,10 +367,10 @@ private enum Vision { } public func callAsFunction( - _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray + _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray ) -> MLXArray { let sequenceLength = x.dim(0) - let B = gridThw[0].t + let B = frames[0].t let L = sequenceLength / B let qkv = qkv(x) @@ -435,13 +435,13 @@ private enum Vision { } func callAsFunction( - _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray + _ hiddenStates: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray ) -> MLXArray { var hiddenStates = hiddenStates + attention( norm1(hiddenStates), - gridThw: gridThw, + frames: frames, rotaryPositionEmbedding: rotaryPositionEmbedding ) hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) @@ -479,10 +479,10 @@ private enum Vision { spatialMergeSize: 2) } - func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray { + func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray { var positionIds = [MLXArray]() - for row in gridThw { + for row in frames { let (t, h, w) = row.values var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1) @@ -516,22 +516,22 @@ private enum Vision { } let indices = concatenated(positionIds, axis: 0) - let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0 - let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[ + let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0 + let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[ indices] return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1) } - public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray { + public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray { var hiddenStates = patchEmbed(hiddenStates) - let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw) + let rotaryPositionEmbedding = rotaryPositionEmbedding(frames) - let batchSize = gridThw.count + let batchSize = frames.count for block in blocks { hiddenStates = block( - hiddenStates, gridThw: gridThw, + hiddenStates, frames: frames, rotaryPositionEmbedding: rotaryPositionEmbedding) } @@ -539,7 +539,7 @@ private enum Vision { } private func isMLXWeight(_ array: MLXArray) -> Bool { - if array.ndim != 4 && array.ndim != 5 { + if array.ndim != 4, array.ndim != 5 { return false } @@ -585,6 +585,10 @@ private enum Vision { /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``. public class Qwen2VLProcessor: UserInputProcessor { + enum Qwen2VLProcessorError: Error { + case framesIsNil + } + private let config: Qwen2VLProcessorConfiguration private let tokenizer: any Tokenizer @@ -690,110 +694,96 @@ public class Qwen2VLProcessor: UserInputProcessor { return (flattenedPatches, .init(gridT, gridH, gridW)) } - public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) 
-> String { - // the tokenizer does have a chat template and it expects messages - // like this: - // - // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'}, - // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}] - // - // The output of the prompt template is fed into - // image_processing_qwen2_vl.preprocess where it is further augmented - // by replacing tokens according to imageTHW. - // - // Neither the structured content nor the postprocessing of the template - // are supported in current Tokenizer/Jinja (swift) so handle that here. - - var messages = prompt.asMessages() - if messages[0]["role"] != "system" { - messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0) - } - - let lastIndex = messages.count - 1 - var lastMessage = messages[lastIndex]["content"] ?? "" + public func prepare(input: UserInput) async throws -> LMInput { + let messages = input.prompt.asMessages() + var promptTokens = try tokenizer.applyChatTemplate(messages: messages) - // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image - let mergeLength = config.mergeSize * config.mergeSize - for thw in imageTHW ?? [] { - lastMessage += "<|vision_start|>" - lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength) - .joined() - lastMessage += "<|vision_end|>" + // Text-only input + if input.images.isEmpty, input.videos.isEmpty { + return LMInput(tokens: MLXArray(promptTokens)) } - for thw in videoTHW ?? [] { - lastMessage += "<|vision_start|>" - lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength) - .joined() - lastMessage += "<|vision_end|>" + // Process images if any + var processedImage: LMInput.ProcessedImage? + if !input.images.isEmpty { + let imagePixelsAndFrames = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 }) + processedImage = LMInput.ProcessedImage( + pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 }) + if let imageFrames = processedImage?.frames { + promptTokens = try replacePaddingTokens( + in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>") + } } - messages[lastIndex]["content"] = lastMessage - - return - messages - .map { - "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>" + // Process videos if any + var processedVideo: LMInput.ProcessedVideo? + if !input.videos.isEmpty { + var videosAsImageSequences = [[CIImage]]() + for video in input.videos { + if let imageSequence = try? 
await MediaProcessing.asCIImageSequence( + video.asAVAsset(), samplesPerSecond: 2) + { + videosAsImageSequences.append(imageSequence) + } + } + let videoPixelsAndFrames = try videosAsImageSequences.map { + try preprocess(images: $0, processing: input.processing) + } + let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 }) + processedVideo = LMInput.ProcessedVideo( + pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 }) + if let videoFrames = processedVideo?.frames { + promptTokens = try replacePaddingTokens( + in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>") } - .joined(separator: "\n") - + "\n<|im_start|>assistant\n" - } - - public func prepare(input: UserInput) async throws -> LMInput { - if input.images.isEmpty && input.videos.isEmpty { - // just a straight text prompt - let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil) - let promptTokens = try tokenizer.encode(text: prompt) - return LMInput(tokens: MLXArray(promptTokens)) } - // image_processing_qwen2_vl.preprocess - let images = try input.images.map { - try preprocess(images: [$0.asCIImage()], processing: input.processing) - } + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + return LMInput( + text: .init(tokens: promptArray, mask: mask), + image: processedImage, + video: processedVideo) + } - var videosAsImageSequences = [[CIImage]]() - for video in input.videos { - if let imageSequence = try? await MediaProcessing.asCIImageSequence( - video.asAVAsset(), samplesPerSecond: 2) - { - videosAsImageSequences.append(imageSequence) - } + func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String) + throws -> [Int] + { + // Replace single padding token with correct number for each image or video frame + let placeholderTokens = try tokenizer.encode( + text: "<|vision_start|>\(paddingToken)<|vision_end|>") + let placeholderRanges = promptTokens.ranges(of: placeholderTokens) + guard placeholderRanges.count == frames.count else { + throw VLMError.processing( + "Number of placeholder tokens does not match number of frames") } - let videos = try videosAsImageSequences.map { - try preprocess(images: $0, processing: input.processing) + let mergeLength = config.mergeSize * config.mergeSize + let replacementSequences = try frames.map { frame in + let paddingCount = frame.product / mergeLength + return try tokenizer.encode( + text: + "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>" + ) } - - let imagePixels: MLXArray? - let image: LMInput.ProcessedImage? - if !images.isEmpty { - imagePixels = concatenated(images.map { $0.0 }) - image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 }) - } else { - imagePixels = nil - image = nil + // Build the final array + var result: [Int] = [] + var currentIndex = promptTokens.startIndex + for (range, replacement) in zip(placeholderRanges, replacementSequences) { + // Add tokens before the placeholder + result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound]) + // Add replacement sequence + result.append(contentsOf: replacement) + currentIndex = range.upperBound } - - let videoPixels: MLXArray? - let video: LMInput.ProcessedVideo? 
- if !videos.isEmpty { - videoPixels = concatenated(videos.map { $0.0 }) - video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 }) - } else { - videoPixels = nil - video = nil + // Add any remaining tokens after the last replacement + if currentIndex < promptTokens.endIndex { + result.append(contentsOf: promptTokens[currentIndex...]) } - - // processing_qwen2_vl.Qwen2VLProcessor - let prompt = prepare( - prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw) - let promptTokens = try tokenizer.encode(text: prompt) - let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) - let mask = ones(like: promptArray).asType(.int8) - - return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video) + return result } - } // MARK: - Model @@ -821,10 +811,10 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) } - private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?) + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?) -> MLXArray { - guard let pixelValues, let gridThw else { + guard let pixelValues, let frames else { return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]) } @@ -832,7 +822,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { let inputEmbeds = languageModel.model.embedTokens(inputIds) // Get the ouptut hidden states from the vision model - var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw) + var hiddenStates = self.visionModel(pixelValues, frames: frames) if hiddenStates.ndim == 2 { hiddenStates = hiddenStates[.newAxis, 0..., 0...] @@ -873,25 +863,25 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { { let dtype = visionModel.patchEmbed.proj.weight.dtype - let imageGridThw = input.image?.imageGridThw - let imagePixels = input.image?.pixels.asType(dtype) - - let videoGridThw = input.video?.videoGridThw - let videoPixels = input.video?.pixels.asType(dtype) + // Process both images and videos if present + var allPixels: [MLXArray] = [] + var allFrames: [THW] = [] - let gridThw: [THW]? - let pixels: MLXArray? + if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames { + allPixels.append(imagePixels.asType(dtype)) + allFrames.append(contentsOf: imageFrames) + } - if videoGridThw == nil { - gridThw = imageGridThw - pixels = imagePixels - } else { - gridThw = videoGridThw - pixels = videoPixels + if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames { + allPixels.append(videoPixels.asType(dtype)) + allFrames.append(contentsOf: videoFrames) } + let pixels = allPixels.isEmpty ? nil : concatenated(allPixels) + let frames = allFrames.isEmpty ? 
nil : allFrames + let inputEmbeddings = self.inputEmbeddings( - inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw) + inputIds: input.text.tokens, pixelValues: pixels, frames: frames) let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 3182ea85..73f654a9 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -11,6 +11,7 @@ public enum VLMError: Error { case maskRequired case singleImageAllowed case imageProcessingFailure(String) + case processing(String) } public struct BaseProcessorConfiguration: Codable, Sendable { From 1cd252b7971b215baeecbb1dfb0e2399a934f7a4 Mon Sep 17 00:00:00 2001 From: David Koski Date: Wed, 5 Feb 2025 08:04:10 -0800 Subject: [PATCH 3/6] remove duplicate entry --- .../xcshareddata/swiftpm/Package.resolved | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 5245a71d..3ed84667 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "327a4376ec20e25f941929e0bd2eefea67914f3c98414e5489f49c7e49eab7ab", + "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c", "pins" : [ { "identity" : "gzipswift", @@ -73,15 +73,6 @@ "version" : "1.1.4" } }, - { - "identity" : "swift-collections", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-collections.git", - "state" : { - "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", - "version" : "1.1.4" - } - }, { "identity" : "swift-markdown-ui", "kind" : "remoteSourceControl", From b66f31298d1b66ccd1e45bee47f7e73395c9cb15 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 12 Feb 2025 14:35:38 +0100 Subject: [PATCH 4/6] Handle mix of images and videos --- Libraries/MLXVLM/Models/Qwen2VL.swift | 40 +++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 393db5c4..daa44c3e 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -841,21 +841,25 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { var imageIndices = [Int]() for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == imageTokenIndex { + if v == imageTokenIndex || v == videoTokenIndex { imageIndices.append(i) } } - if imageIndices.isEmpty { - for (i, v) in inputIds.asArray(Int.self).enumerated() { - if v == videoTokenIndex { - imageIndices.append(i) - } - } + // Make sure shapes match before assignment + var result = inputEmbeds + if result.ndim == 2 { + result = result[.newAxis, 0..., 0...] + } + + if imageFeatures.ndim == 2 { + let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...] + result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures + } else { + result[0..., MLXArray(imageIndices), 0...] = imageFeatures } - inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures - return inputEmbeds + return result } public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) 
throws @@ -863,25 +867,27 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { { let dtype = visionModel.patchEmbed.proj.weight.dtype - // Process both images and videos if present - var allPixels: [MLXArray] = [] + // Process both images and videos together + var allPixels: MLXArray? var allFrames: [THW] = [] if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames { - allPixels.append(imagePixels.asType(dtype)) + allPixels = imagePixels.asType(dtype) allFrames.append(contentsOf: imageFrames) } if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames { - allPixels.append(videoPixels.asType(dtype)) + if allPixels == nil { + allPixels = videoPixels.asType(dtype) + } else { + allPixels = concatenated([allPixels!, videoPixels.asType(dtype)]) + } allFrames.append(contentsOf: videoFrames) } - let pixels = allPixels.isEmpty ? nil : concatenated(allPixels) - let frames = allFrames.isEmpty ? nil : allFrames - let inputEmbeddings = self.inputEmbeddings( - inputIds: input.text.tokens, pixelValues: pixels, frames: frames) + inputIds: input.text.tokens, pixelValues: allPixels, + frames: allFrames.isEmpty ? nil : allFrames) let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) From 908e3e2bda38f6b61cfe88cd3e94e1c876edc4c6 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 12 Feb 2025 17:54:49 +0100 Subject: [PATCH 5/6] Remove Qwen2VLProcessorError --- Libraries/MLXVLM/Models/Qwen2VL.swift | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index daa44c3e..f71e2352 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -584,11 +584,6 @@ private enum Vision { /// /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``. public class Qwen2VLProcessor: UserInputProcessor { - - enum Qwen2VLProcessorError: Error { - case framesIsNil - } - private let config: Qwen2VLProcessorConfiguration private let tokenizer: any Tokenizer From 5532c5801394b029035e4696d136ef6bda88da3b Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 12 Feb 2025 19:03:08 +0100 Subject: [PATCH 6/6] Support images and videos in llm-tool --- Tools/llm-tool/LLMTool.swift | 41 ++++++++++++++----- .../xcshareddata/xcschemes/llm-tool.xcscheme | 6 ++- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift index 13dc0ebd..648ee663 100644 --- a/Tools/llm-tool/LLMTool.swift +++ b/Tools/llm-tool/LLMTool.swift @@ -22,7 +22,7 @@ struct LLMTool: AsyncParsableCommand { /// Command line arguments for loading a model. struct ModelArguments: ParsableArguments, Sendable { - @Option(name: .long, help: "Name of the huggingface model or absolute path to directory") + @Option(name: .long, help: "Name of the Hugging Face model or absolute path to directory") var model: String? 
@Sendable @@ -194,7 +194,6 @@ struct MemoryArguments: ParsableArguments, Sendable { } struct EvaluateCommand: AsyncParsableCommand { - static let configuration = CommandConfiguration( commandName: "eval", abstract: "evaluate prompt and generate text" @@ -207,22 +206,42 @@ struct EvaluateCommand: AsyncParsableCommand { @Option(parsing: .upToNextOption, help: "Resize images to this size (width, height)") var resize: [Int] = [] - @Option(parsing: .upToNextOption, help: "Paths or urls for input images") + @Option(parsing: .upToNextOption, help: "Paths or URLs for input images") var image: [URL] = [] + @Option(parsing: .upToNextOption, help: "Paths or URLs for input videos") + var video: [URL] = [] + private func userInput(modelConfiguration: ModelConfiguration) -> UserInput { - // prompt and images let prompt = (try? generate.resolvePrompt(configuration: modelConfiguration)) ?? modelConfiguration.defaultPrompt + let images = image.map { UserInput.Image.url($0) } - var input = UserInput(prompt: prompt, images: images) + let videos = video.map { UserInput.Video.url($0) } + + let messages: [[String: Any]] = [ + [ + "role": "user", + "content": [ + ["type": "text", "text": prompt] + ] + // Messages format for Qwen 2 VL, Qwen 2.5 VL. May need to be adapted for other models. + + images.map { _ in ["type": "image"] } + + videos.map { _ in ["type": "video"] }, + ] + ] + + var input = UserInput( + messages: messages, + images: images, + videos: videos + ) - // processing instructions if !resize.isEmpty { let size: CGSize if resize.count == 1 { - // single value represents width/height + // Single value represents width/height let v = resize[0] size = CGSize(width: v, height: v) } else { @@ -241,8 +260,8 @@ struct EvaluateCommand: AsyncParsableCommand { let modelFactory: ModelFactory let defaultModel: ModelConfiguration - // switch between LLM and VLM - let vlm = image.count > 0 + // Switch between LLM and VLM based on presence of media + let vlm = !image.isEmpty || !video.isEmpty if vlm { modelFactory = VLMModelFactory.shared defaultModel = MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit @@ -251,12 +270,12 @@ struct EvaluateCommand: AsyncParsableCommand { defaultModel = MLXLLM.ModelRegistry.mistral7B4bit } - // load the model + // Load the model let modelContainer = try await memory.start { [args] in try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory) } - // get the resolved configuration (this has the default prompt) + // Get the resolved configuration (this has the default prompt) let modelConfiguration = modelContainer.configuration if !generate.quiet { diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme index d7782987..11c3173f 100644 --- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme +++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme @@ -77,10 +77,14 @@ + isEnabled = "NO"> + +
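
Usage note (not part of the patches): the series above introduces a messages-based UserInput API (PATCH 2/6) and wires it into VLMEval and llm-tool (PATCH 6/6). A minimal sketch of how that API might be called, based directly on the ContentView.swift and LLMTool.swift diffs, is shown below. The prompt text and file paths are hypothetical placeholders, and the ["type": "image"] / ["type": "video"] placeholder entries follow the Qwen 2 VL message format used in the diffs; other models may need a different message shape.

    import Foundation
    import MLXLMCommon

    // Hypothetical inputs for illustration only.
    let prompt = "Describe what you see."
    let images: [UserInput.Image] = [.url(URL(fileURLWithPath: "/path/to/image.png"))]
    let videos: [UserInput.Video] = [.url(URL(fileURLWithPath: "/path/to/video.mp4"))]

    // One placeholder entry per attached image or video, as in ContentView/LLMTool.
    let messages: [Message] = [
        [
            "role": "user",
            "content": [
                ["type": "text", "text": prompt]
            ]
                + images.map { _ in ["type": "image"] }
                + videos.map { _ in ["type": "video"] },
        ]
    ]

    var userInput = UserInput(messages: messages, images: images, videos: videos)
    // Optional resize, mirroring the VLMEval example.
    userInput.processing.resize = .init(width: 448, height: 448)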