diff --git a/Applications/VLMEval/ContentView.swift b/Applications/VLMEval/ContentView.swift
index 33b4debe..041fa097 100644
--- a/Applications/VLMEval/ContentView.swift
+++ b/Applications/VLMEval/ContentView.swift
@@ -383,9 +383,33 @@ class VLMEvaluator {
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
let result = try await modelContainer.perform { context in
- let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : []
- let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
- var userInput = UserInput(prompt: prompt, images: images, videos: videos)
+ let images: [UserInput.Image] =
+ if let image {
+ [UserInput.Image.ciImage(image)]
+ } else {
+ []
+ }
+ let videos: [UserInput.Video] =
+ if let videoURL {
+ [.url(videoURL)]
+ } else {
+ []
+ }
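+                    // Structured message content (text plus image/video placeholders) in the
+                    // format expected by the Qwen2VL chat template; other models may need a
+                    // different message format.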
+ var userInput = UserInput(
+ messages: [
+ [
+ "role": "user",
+ "content": [
+ ["type": "text", "text": prompt]
+ ]
+ + images.map { _ in
+ ["type": "image"]
+ }
+ + videos.map { _ in
+ ["type": "video"]
+ },
+ ]
+ ], images: images, videos: videos)
userInput.processing.resize = .init(width: 448, height: 448)
let input = try await context.processor.prepare(input: userInput)
diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift
index 9ba7efeb..a7cfd123 100644
--- a/Libraries/MLXLLM/LLMModelFactory.swift
+++ b/Libraries/MLXLLM/LLMModelFactory.swift
@@ -249,7 +249,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
// but that is not public so just fall back to text
let prompt = input.prompt
.asMessages()
- .compactMap { $0["content"] }
+ .compactMap { $0["content"] as? String }
.joined(separator: ". ")
let promptTokens = tokenizer.encode(text: prompt)
return LMInput(tokens: MLXArray(promptTokens))
diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift
index d51b0443..25e900aa 100644
--- a/Libraries/MLXLMCommon/LanguageModel.swift
+++ b/Libraries/MLXLMCommon/LanguageModel.swift
@@ -69,14 +69,16 @@ public struct LMInput {
/// Representation of prepared input image(s).
public struct ProcessedImage {
+ /// Concatenated pixels from one or more images
public let pixels: MLXArray
- public let imageGridThw: [THW]?
+ /// Time, height, and width of the images
+ public let frames: [THW]?
public init(
- pixels: MLXArray, imageGridThw: [THW]? = nil
+ pixels: MLXArray, frames: [THW]? = nil
) {
self.pixels = pixels
- self.imageGridThw = imageGridThw
+ self.frames = frames
}
}
@@ -85,13 +87,13 @@ public struct LMInput {
public struct ProcessedVideo {
public let pixels: MLXArray
- public let videoGridThw: [THW]?
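+        /// Time, height, and width of the video frames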
+ public let frames: [THW]?
public init(
- pixels: MLXArray, videoGridThw: [THW]? = nil
+ pixels: MLXArray, frames: [THW]? = nil
) {
self.pixels = pixels
- self.videoGridThw = videoGridThw
+ self.frames = frames
}
}
diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift
index 8c45cd01..d2c96d20 100644
--- a/Libraries/MLXLMCommon/UserInput.swift
+++ b/Libraries/MLXLMCommon/UserInput.swift
@@ -6,18 +6,19 @@ import Foundation
import MLX
import Tokenizers
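+/// A single chat message, e.g. ["role": "user", "content": ...]; "content" may be a plain
+/// string or structured content (an array of ["type": ...] dictionaries).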
+public typealias Message = [String: Any]
+
/// Container for raw user input.
///
/// A ``UserInputProcessor`` can convert this to ``LMInput``.
/// See also ``ModelContext``.
public struct UserInput: Sendable {
-
/// Representation of a prompt or series of messages (conversation).
public enum Prompt: Sendable, CustomStringConvertible {
case text(String)
- case messages([[String: String]])
+ case messages([Message])
- public func asMessages() -> [[String: String]] {
+ public func asMessages() -> [Message] {
switch self {
case .text(let text):
return [["role": "user", "content": text]]
@@ -144,11 +145,13 @@ public struct UserInput: Sendable {
}
public init(
- messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+ messages: [Message], images: [Image] = [Image](), videos: [Video] = [Video](),
+ tools: [ToolSpec]? = nil,
additionalContext: [String: Any]? = nil
) {
self.prompt = .messages(messages)
self.images = images
+ self.videos = videos
self.tools = tools
self.additionalContext = additionalContext
}
diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift
index 5a73b539..9effd20d 100644
--- a/Libraries/MLXVLM/Models/Idefics3.swift
+++ b/Libraries/MLXVLM/Models/Idefics3.swift
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
}
public func prepare(input: UserInput) throws -> LMInput {
- let prompt = input.prompt.asMessages().last?["content"] ?? ""
+ let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""
if input.images.isEmpty {
// No image scenario
diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift
index a103ccb4..76cc89ca 100644
--- a/Libraries/MLXVLM/Models/Paligemma.swift
+++ b/Libraries/MLXVLM/Models/Paligemma.swift
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
}
// this doesn't have a chat template so just use the last message.
- var prompt = input.prompt.asMessages().last?["content"] ?? ""
+ var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""
// based on transformers/processing_paligemma
let count = input.images.count * config.imageSequenceLength
diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift
index cb4c2af0..f71e2352 100644
--- a/Libraries/MLXVLM/Models/Qwen2VL.swift
+++ b/Libraries/MLXVLM/Models/Qwen2VL.swift
@@ -367,10 +367,10 @@ private enum Vision {
}
public func callAsFunction(
- _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+ _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
) -> MLXArray {
let sequenceLength = x.dim(0)
- let B = gridThw[0].t
+ let B = frames[0].t
let L = sequenceLength / B
let qkv = qkv(x)
@@ -435,13 +435,13 @@ private enum Vision {
}
func callAsFunction(
- _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+ _ hiddenStates: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
) -> MLXArray {
var hiddenStates =
hiddenStates
+ attention(
norm1(hiddenStates),
- gridThw: gridThw,
+ frames: frames,
rotaryPositionEmbedding: rotaryPositionEmbedding
)
hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
@@ -479,10 +479,10 @@ private enum Vision {
spatialMergeSize: 2)
}
- func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray {
+ func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray {
var positionIds = [MLXArray]()
- for row in gridThw {
+ for row in frames {
let (t, h, w) = row.values
var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
@@ -516,22 +516,22 @@ private enum Vision {
}
let indices = concatenated(positionIds, axis: 0)
- let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0
- let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
+ let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0
+ let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[
indices]
return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
}
- public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
+ public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray {
var hiddenStates = patchEmbed(hiddenStates)
- let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+ let rotaryPositionEmbedding = rotaryPositionEmbedding(frames)
- let batchSize = gridThw.count
+ let batchSize = frames.count
for block in blocks {
hiddenStates = block(
- hiddenStates, gridThw: gridThw,
+ hiddenStates, frames: frames,
rotaryPositionEmbedding: rotaryPositionEmbedding)
}
@@ -539,7 +539,7 @@ private enum Vision {
}
private func isMLXWeight(_ array: MLXArray) -> Bool {
- if array.ndim != 4 && array.ndim != 5 {
+ if array.ndim != 4, array.ndim != 5 {
return false
}
@@ -584,7 +584,6 @@ private enum Vision {
///
/// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
public class Qwen2VLProcessor: UserInputProcessor {
-
private let config: Qwen2VLProcessorConfiguration
private let tokenizer: any Tokenizer
@@ -690,110 +689,96 @@ public class Qwen2VLProcessor: UserInputProcessor {
return (flattenedPatches, .init(gridT, gridH, gridW))
}
- public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String {
- // the tokenizer does have a chat template and it expects messages
- // like this:
- //
- // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
- // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
- //
- // The output of the prompt template is fed into
- // image_processing_qwen2_vl.preprocess where it is further augmented
- // by replacing tokens according to imageTHW.
- //
- // Neither the structured content nor the postprocessing of the template
- // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
- var messages = prompt.asMessages()
- if messages[0]["role"] != "system" {
- messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
- }
-
- let lastIndex = messages.count - 1
- var lastMessage = messages[lastIndex]["content"] ?? ""
+ public func prepare(input: UserInput) async throws -> LMInput {
+ let messages = input.prompt.asMessages()
+ var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
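+        // The chat template emits one <|image_pad|> / <|video_pad|> placeholder per image or
+        // video; replacePaddingTokens(in:frames:paddingToken:) below expands each placeholder
+        // to the token count implied by its frame grid.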
- // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
- let mergeLength = config.mergeSize * config.mergeSize
- for thw in imageTHW ?? [] {
- lastMessage += "<|vision_start|>"
- lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
- .joined()
- lastMessage += "<|vision_end|>"
+ // Text-only input
+ if input.images.isEmpty, input.videos.isEmpty {
+ return LMInput(tokens: MLXArray(promptTokens))
}
- for thw in videoTHW ?? [] {
- lastMessage += "<|vision_start|>"
- lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength)
- .joined()
- lastMessage += "<|vision_end|>"
+ // Process images if any
+ var processedImage: LMInput.ProcessedImage?
+ if !input.images.isEmpty {
+ let imagePixelsAndFrames = try input.images.map {
+ try preprocess(images: [$0.asCIImage()], processing: input.processing)
+ }
+ let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
+ processedImage = LMInput.ProcessedImage(
+ pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
+ if let imageFrames = processedImage?.frames {
+ promptTokens = try replacePaddingTokens(
+ in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>")
+ }
}
- messages[lastIndex]["content"] = lastMessage
-
- return
- messages
- .map {
- "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+ // Process videos if any
+ var processedVideo: LMInput.ProcessedVideo?
+ if !input.videos.isEmpty {
+ var videosAsImageSequences = [[CIImage]]()
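+            // Sample each video at 2 frames per second; the sampled frames are then
+            // preprocessed like a batch of images.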
+ for video in input.videos {
+ if let imageSequence = try? await MediaProcessing.asCIImageSequence(
+ video.asAVAsset(), samplesPerSecond: 2)
+ {
+ videosAsImageSequences.append(imageSequence)
+ }
+ }
+ let videoPixelsAndFrames = try videosAsImageSequences.map {
+ try preprocess(images: $0, processing: input.processing)
+ }
+ let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
+ processedVideo = LMInput.ProcessedVideo(
+ pixels: videoPixelsConcatenated, frames: videoPixelsAndFrames.map { $0.1 })
+ if let videoFrames = processedVideo?.frames {
+ promptTokens = try replacePaddingTokens(
+ in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
}
- .joined(separator: "\n")
- + "\n<|im_start|>assistant\n"
- }
-
- public func prepare(input: UserInput) async throws -> LMInput {
- if input.images.isEmpty && input.videos.isEmpty {
- // just a straight text prompt
- let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil)
- let promptTokens = try tokenizer.encode(text: prompt)
- return LMInput(tokens: MLXArray(promptTokens))
}
- // image_processing_qwen2_vl.preprocess
- let images = try input.images.map {
- try preprocess(images: [$0.asCIImage()], processing: input.processing)
- }
+ let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
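+        // The mask is all ones: a single, unpadded prompt sequence.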
+ let mask = ones(like: promptArray).asType(.int8)
+ return LMInput(
+ text: .init(tokens: promptArray, mask: mask),
+ image: processedImage,
+ video: processedVideo)
+ }
- var videosAsImageSequences = [[CIImage]]()
- for video in input.videos {
- if let imageSequence = try? await MediaProcessing.asCIImageSequence(
- video.asAVAsset(), samplesPerSecond: 2)
- {
- videosAsImageSequences.append(imageSequence)
- }
+ func replacePaddingTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
+ throws -> [Int]
+ {
+        // Replace the single padding token emitted for each image or video with the correct
+        // number of padding tokens for its frame grid
+ let placeholderTokens = try tokenizer.encode(
+ text: "<|vision_start|>\(paddingToken)<|vision_end|>")
+ let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
+ guard placeholderRanges.count == frames.count else {
+ throw VLMError.processing(
+ "Number of placeholder tokens does not match number of frames")
}
- let videos = try videosAsImageSequences.map {
- try preprocess(images: $0, processing: input.processing)
+ let mergeLength = config.mergeSize * config.mergeSize
+ let replacementSequences = try frames.map { frame in
+ let paddingCount = frame.product / mergeLength
+ return try tokenizer.encode(
+ text:
+ "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
+ )
}
-
- let imagePixels: MLXArray?
- let image: LMInput.ProcessedImage?
- if !images.isEmpty {
- imagePixels = concatenated(images.map { $0.0 })
- image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
- } else {
- imagePixels = nil
- image = nil
+ // Build the final array
+ var result: [Int] = []
+ var currentIndex = promptTokens.startIndex
+ for (range, replacement) in zip(placeholderRanges, replacementSequences) {
+ // Add tokens before the placeholder
+ result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
+ // Add replacement sequence
+ result.append(contentsOf: replacement)
+ currentIndex = range.upperBound
}
-
- let videoPixels: MLXArray?
- let video: LMInput.ProcessedVideo?
- if !videos.isEmpty {
- videoPixels = concatenated(videos.map { $0.0 })
- video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
- } else {
- videoPixels = nil
- video = nil
+ // Add any remaining tokens after the last replacement
+ if currentIndex < promptTokens.endIndex {
+ result.append(contentsOf: promptTokens[currentIndex...])
}
-
- // processing_qwen2_vl.Qwen2VLProcessor
- let prompt = prepare(
- prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
- let promptTokens = try tokenizer.encode(text: prompt)
- let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
- let mask = ones(like: promptArray).asType(.int8)
-
- return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
+ return result
}
-
}
// MARK: - Model
@@ -821,10 +806,10 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
}
- private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
+ private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?)
-> MLXArray
{
- guard let pixelValues, let gridThw else {
+ guard let pixelValues, let frames else {
return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis])
}
@@ -832,7 +817,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
let inputEmbeds = languageModel.model.embedTokens(inputIds)
         // Get the output hidden states from the vision model
- var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw)
+ var hiddenStates = self.visionModel(pixelValues, frames: frames)
if hiddenStates.ndim == 2 {
hiddenStates = hiddenStates[.newAxis, 0..., 0...]
@@ -851,21 +836,25 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
var imageIndices = [Int]()
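+        // Positions of both image and video pad tokens, in prompt order; the corresponding
+        // vision features are written into these positions below.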
for (i, v) in inputIds.asArray(Int.self).enumerated() {
- if v == imageTokenIndex {
+ if v == imageTokenIndex || v == videoTokenIndex {
imageIndices.append(i)
}
}
- if imageIndices.isEmpty {
- for (i, v) in inputIds.asArray(Int.self).enumerated() {
- if v == videoTokenIndex {
- imageIndices.append(i)
- }
- }
+ // Make sure shapes match before assignment
+ var result = inputEmbeds
+ if result.ndim == 2 {
+ result = result[.newAxis, 0..., 0...]
+ }
+
+ if imageFeatures.ndim == 2 {
+ let reshapedFeatures = imageFeatures[.newAxis, 0..., 0...]
+ result[0..., MLXArray(imageIndices), 0...] = reshapedFeatures
+ } else {
+ result[0..., MLXArray(imageIndices), 0...] = imageFeatures
}
- inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
- return inputEmbeds
+ return result
}
public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
@@ -873,25 +862,27 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
{
let dtype = visionModel.patchEmbed.proj.weight.dtype
- let imageGridThw = input.image?.imageGridThw
- let imagePixels = input.image?.pixels.asType(dtype)
+ // Process both images and videos together
+ var allPixels: MLXArray?
+ var allFrames: [THW] = []
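+        // Image pixels are concatenated first, then video pixels, with the frame metadata
+        // kept in the same order as the pixel buffers.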
- let videoGridThw = input.video?.videoGridThw
- let videoPixels = input.video?.pixels.asType(dtype)
-
- let gridThw: [THW]?
- let pixels: MLXArray?
+ if let imagePixels = input.image?.pixels, let imageFrames = input.image?.frames {
+ allPixels = imagePixels.asType(dtype)
+ allFrames.append(contentsOf: imageFrames)
+ }
- if videoGridThw == nil {
- gridThw = imageGridThw
- pixels = imagePixels
- } else {
- gridThw = videoGridThw
- pixels = videoPixels
+ if let videoPixels = input.video?.pixels, let videoFrames = input.video?.frames {
+ if allPixels == nil {
+ allPixels = videoPixels.asType(dtype)
+ } else {
+ allPixels = concatenated([allPixels!, videoPixels.asType(dtype)])
+ }
+ allFrames.append(contentsOf: videoFrames)
}
let inputEmbeddings = self.inputEmbeddings(
- inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
+ inputIds: input.text.tokens, pixelValues: allPixels,
+ frames: allFrames.isEmpty ? nil : allFrames)
let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)
diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift
index 3182ea85..73f654a9 100644
--- a/Libraries/MLXVLM/VLMModelFactory.swift
+++ b/Libraries/MLXVLM/VLMModelFactory.swift
@@ -11,6 +11,7 @@ public enum VLMError: Error {
case maskRequired
case singleImageAllowed
case imageProcessingFailure(String)
+ case processing(String)
}
public struct BaseProcessorConfiguration: Codable, Sendable {
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 13dc0ebd..648ee663 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -22,7 +22,7 @@ struct LLMTool: AsyncParsableCommand {
/// Command line arguments for loading a model.
struct ModelArguments: ParsableArguments, Sendable {
- @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
+ @Option(name: .long, help: "Name of the Hugging Face model or absolute path to directory")
var model: String?
@Sendable
@@ -194,7 +194,6 @@ struct MemoryArguments: ParsableArguments, Sendable {
}
struct EvaluateCommand: AsyncParsableCommand {
-
static let configuration = CommandConfiguration(
commandName: "eval",
abstract: "evaluate prompt and generate text"
@@ -207,22 +206,42 @@ struct EvaluateCommand: AsyncParsableCommand {
@Option(parsing: .upToNextOption, help: "Resize images to this size (width, height)")
var resize: [Int] = []
- @Option(parsing: .upToNextOption, help: "Paths or urls for input images")
+ @Option(parsing: .upToNextOption, help: "Paths or URLs for input images")
var image: [URL] = []
+ @Option(parsing: .upToNextOption, help: "Paths or URLs for input videos")
+ var video: [URL] = []
+
private func userInput(modelConfiguration: ModelConfiguration) -> UserInput {
- // prompt and images
let prompt =
(try? generate.resolvePrompt(configuration: modelConfiguration))
?? modelConfiguration.defaultPrompt
+
let images = image.map { UserInput.Image.url($0) }
- var input = UserInput(prompt: prompt, images: images)
+ let videos = video.map { UserInput.Video.url($0) }
+
+ let messages: [[String: Any]] = [
+ [
+ "role": "user",
+ "content": [
+ ["type": "text", "text": prompt]
+ ]
+ // Messages format for Qwen 2 VL, Qwen 2.5 VL. May need to be adapted for other models.
+ + images.map { _ in ["type": "image"] }
+ + videos.map { _ in ["type": "video"] },
+ ]
+ ]
+
+ var input = UserInput(
+ messages: messages,
+ images: images,
+ videos: videos
+ )
- // processing instructions
if !resize.isEmpty {
let size: CGSize
if resize.count == 1 {
- // single value represents width/height
+ // Single value represents width/height
let v = resize[0]
size = CGSize(width: v, height: v)
} else {
@@ -241,8 +260,8 @@ struct EvaluateCommand: AsyncParsableCommand {
let modelFactory: ModelFactory
let defaultModel: ModelConfiguration
- // switch between LLM and VLM
- let vlm = image.count > 0
+ // Switch between LLM and VLM based on presence of media
+ let vlm = !image.isEmpty || !video.isEmpty
if vlm {
modelFactory = VLMModelFactory.shared
defaultModel = MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit
@@ -251,12 +270,12 @@ struct EvaluateCommand: AsyncParsableCommand {
defaultModel = MLXLLM.ModelRegistry.mistral7B4bit
}
- // load the model
+ // Load the model
let modelContainer = try await memory.start { [args] in
try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory)
}
- // get the resolved configuration (this has the default prompt)
+ // Get the resolved configuration (this has the default prompt)
let modelConfiguration = modelContainer.configuration
if !generate.quiet {
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 8f722eba..3ed84667 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,5 +1,5 @@
{
- "originHash" : "327a4376ec20e25f941929e0bd2eefea67914f3c98414e5489f49c7e49eab7ab",
+ "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c",
"pins" : [
{
"identity" : "gzipswift",
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
index d7782987..11c3173f 100644
--- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -77,10 +77,14 @@
+ isEnabled = "NO">
+
+