
Commit c743526

Use chat templates for vision models (#173)

* Update swift-transformers
* Use chat template for Qwen 2 VL
* Remove duplicate entry
* Handle mix of images and videos
* Remove Qwen2VLProcessorError
* Support images and videos in llm-tool

1 parent 983eaac commit c743526

11 files changed (+203 −159 lines)


Applications/VLMEval/ContentView.swift

Lines changed: 27 additions & 3 deletions
@@ -383,9 +383,33 @@ class VLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : []
-            let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
-            var userInput = UserInput(prompt: prompt, images: images, videos: videos)
+            let images: [UserInput.Image] =
+                if let image {
+                    [UserInput.Image.ciImage(image)]
+                } else {
+                    []
+                }
+            let videos: [UserInput.Video] =
+                if let videoURL {
+                    [.url(videoURL)]
+                } else {
+                    []
+                }
+            var userInput = UserInput(
+                messages: [
+                    [
+                        "role": "user",
+                        "content": [
+                            ["type": "text", "text": prompt]
+                        ]
+                            + images.map { _ in
+                                ["type": "image"]
+                            }
+                            + videos.map { _ in
+                                ["type": "video"]
+                            },
+                    ]
+                ], images: images, videos: videos)
             userInput.processing.resize = .init(width: 448, height: 448)

             let input = try await context.processor.prepare(input: userInput)
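This hunk is the heart of the commit: instead of a bare prompt string, the caller now builds a structured chat message whose content array carries one "text" entry for the prompt plus one "image" or "video" entry per attachment, so the model's chat template can position the media tokens itself. A minimal standalone sketch of the same pattern, assuming only the APIs shown in this commit (the helper name makeUserInput is hypothetical, not part of the diff):

    import MLXLMCommon

    // Hypothetical helper mirroring the message shape built above:
    // one "text" part for the prompt, plus one typed placeholder per
    // attachment, with the actual media carried in images/videos.
    func makeUserInput(
        prompt: String, images: [UserInput.Image], videos: [UserInput.Video]
    ) -> UserInput {
        let content: [[String: Any]] =
            [["type": "text", "text": prompt]]
            + images.map { _ in ["type": "image"] }
            + videos.map { _ in ["type": "video"] }
        return UserInput(
            messages: [["role": "user", "content": content]],
            images: images, videos: videos)
    }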

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
         // but that is not public so just fall back to text
         let prompt = input.prompt
             .asMessages()
-            .compactMap { $0["content"] }
+            .compactMap { $0["content"] as? String }
             .joined(separator: ". ")
         let promptTokens = tokenizer.encode(text: prompt)
         return LMInput(tokens: MLXArray(promptTokens))
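Because Message is now [String: Any] (see UserInput.swift below), a message's "content" can be either a plain String or an array of typed parts, so the conditional cast is needed to recover strings; structured entries are simply skipped by compactMap. The same cast appears in the Idefics3 and Paligemma processors further down. A small illustration of the behavior:

    let messages: [Message] = [
        ["role": "user", "content": "Describe the image"],  // plain text
        ["role": "user", "content": [["type": "image"]]],  // structured, not a String
    ]
    let prompt = messages
        .compactMap { $0["content"] as? String }
        .joined(separator: ". ")
    // prompt == "Describe the image"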

Libraries/MLXLMCommon/LanguageModel.swift

Lines changed: 8 additions & 6 deletions
@@ -69,14 +69,16 @@ public struct LMInput {
     /// Representation of prepared input image(s).
     public struct ProcessedImage {

+        /// Concatenated pixels from one or more images
         public let pixels: MLXArray
-        public let imageGridThw: [THW]?
+        /// Time, height, and width of the images
+        public let frames: [THW]?

         public init(
-            pixels: MLXArray, imageGridThw: [THW]? = nil
+            pixels: MLXArray, frames: [THW]? = nil
         ) {
             self.pixels = pixels
-            self.imageGridThw = imageGridThw
+            self.frames = frames
         }
     }

@@ -85,13 +87,13 @@ public struct LMInput {
     public struct ProcessedVideo {

         public let pixels: MLXArray
-        public let videoGridThw: [THW]?
+        public let frames: [THW]?

         public init(
-            pixels: MLXArray, videoGridThw: [THW]? = nil
+            pixels: MLXArray, frames: [THW]? = nil
         ) {
             self.pixels = pixels
-            self.videoGridThw = videoGridThw
+            self.frames = frames
         }
     }
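The rename from imageGridThw/videoGridThw to frames gives images and videos a single vocabulary: each THW entry records the temporal, height, and width extent of one processed input. A sketch of how a processor might populate the renamed field, assuming an unlabeled THW(t, h, w) initializer and using placeholder shapes:

    import MLX
    import MLXLMCommon

    // Placeholder pixel buffer; the shape here is illustrative only.
    let pixels = zeros([1024, 1176])
    let image = LMInput.ProcessedImage(
        pixels: pixels,
        frames: [THW(1, 32, 32)]  // one image: t = 1, 32×32 patch grid (assumed)
    )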

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 7 additions & 4 deletions
@@ -6,18 +6,19 @@ import Foundation
 import MLX
 import Tokenizers

+public typealias Message = [String: Any]
+
 /// Container for raw user input.
 ///
 /// A ``UserInputProcessor`` can convert this to ``LMInput``.
 /// See also ``ModelContext``.
 public struct UserInput: Sendable {
-
     /// Representation of a prompt or series of messages (conversation).
     public enum Prompt: Sendable, CustomStringConvertible {
         case text(String)
-        case messages([[String: String]])
+        case messages([Message])

-        public func asMessages() -> [[String: String]] {
+        public func asMessages() -> [Message] {
             switch self {
             case .text(let text):
                 return [["role": "user", "content": text]]

@@ -144,11 +145,13 @@ public struct UserInput: Sendable {
     }

     public init(
-        messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        messages: [Message], images: [Image] = [Image](), videos: [Video] = [Video](),
+        tools: [ToolSpec]? = nil,
         additionalContext: [String: Any]? = nil
     ) {
         self.prompt = .messages(messages)
         self.images = images
+        self.videos = videos
         self.tools = tools
         self.additionalContext = additionalContext
     }
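With videos accepted alongside images, a caller can hand a processor a full multimodal conversation in one call. A usage sketch of the extended initializer, with placeholder assets (any CIImage or local video URL would do):

    import CoreImage
    import Foundation
    import MLXLMCommon

    // Placeholder assets for illustration.
    let stillImage = CIImage(color: .red)
        .cropped(to: .init(x: 0, y: 0, width: 64, height: 64))
    let clipURL = URL(fileURLWithPath: "/tmp/clip.mp4")

    let input = UserInput(
        messages: [
            [
                "role": "user",
                "content": [
                    ["type": "text", "text": "Compare the photo and the clip."],
                    ["type": "image"],
                    ["type": "video"],
                ],
            ]
        ],
        images: [.ciImage(stillImage)],
        videos: [.url(clipURL)]
    )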

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
     }

     public func prepare(input: UserInput) throws -> LMInput {
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+        let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         if input.images.isEmpty {
             // No image scenario

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
     }

         // this doesn't have a chat template so just use the last message.
-        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+        var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         // based on transformers/processing_paligemma
         let count = input.images.count * config.imageSequenceLength
