Commit 96cdc14: Working on a solution

1 parent 4547cf1

5 files changed: +32 -21 lines

Applications/VLMEval/ContentView.swift

Lines changed: 10 additions & 1 deletion
@@ -331,7 +331,16 @@ class VLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            var userInput = UserInput(prompt: prompt, images: [.ciImage(image)])
+            var userInput = UserInput(
+                messages: [
+                    [
+                        "role": "user",
+                        "content": [
+                            ["type": "text", "text": prompt],
+                            ["type": "image"],
+                        ],
+                    ]
+                ], images: [.ciImage(image)])
             userInput.processing.resize = .init(width: 448, height: 448)

             let input = try await context.processor.prepare(input: userInput)
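
Note: a minimal sketch of the calling convention this change adopts, shown outside the app code. It assumes only what the commit itself adds (the Message typealias and the UserInput(messages:images:) initializer in MLXLMCommon); the helper function name is illustrative.

    import CoreImage
    import MLXLMCommon

    // Sketch only: build a structured user message with a text fragment and an
    // image placeholder, passing the actual image separately via `images`.
    func makeInput(prompt: String, image: CIImage) -> UserInput {
        let messages: [Message] = [
            [
                "role": "user",
                "content": [
                    ["type": "text", "text": prompt],
                    ["type": "image"],
                ],
            ]
        ]
        var input = UserInput(messages: messages, images: [.ciImage(image)])
        // Same resize the app applies before the processor runs.
        input.processing.resize = .init(width: 448, height: 448)
        return input
    }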

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 5 additions & 4 deletions
@@ -4,18 +4,19 @@ import CoreImage
 import Foundation
 import MLX

+public typealias Message = [String: Any]
+
 /// Container for raw user input.
 ///
 /// A ``UserInputProcessor`` can convert this to ``LMInput``.
 /// See also ``ModelContext``.
 public struct UserInput: Sendable {
-
     /// Representation of a prompt or series of messages (conversation).
     public enum Prompt: Sendable, CustomStringConvertible {
         case text(String)
-        case messages([[String: String]])
+        case messages([Message])

-        public func asMessages() -> [[String: String]] {
+        public func asMessages() -> [Message] {
             switch self {
             case .text(let text):
                 return [["role": "user", "content": text]]
@@ -116,7 +117,7 @@ public struct UserInput {
         self.images = images
     }

-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(messages: [Message], images: [Image] = [Image]()) {
        self.prompt = .messages(messages)
        self.images = images
    }
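
Note: Message is [String: Any] rather than [String: String] because "content" can now be either a plain string (what asMessages() produces for .text prompts) or an array of typed fragments (the structured format used in the VLMEval change above). A small sketch of what that means for code that consumes Message values; the plainText(of:) helper is hypothetical and not part of the commit.

    import MLXLMCommon

    // Hypothetical helper: flatten a Message's content to plain text, whether it
    // is a String or an array of ["type": ..., "text": ...] fragments.
    func plainText(of message: Message) -> String {
        if let text = message["content"] as? String {
            return text
        }
        if let fragments = message["content"] as? [[String: Any]] {
            return fragments.compactMap { $0["text"] as? String }.joined(separator: " ")
        }
        return ""
    }

    let fromText = UserInput.Prompt.text("Describe the image").asMessages()
    // fromText == [["role": "user", "content": "Describe the image"]]
    print(plainText(of: fromText[0]))  // "Describe the image"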

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
     }

     public func prepare(input: UserInput) throws -> LMInput {
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+        let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         if input.images.isEmpty {
             // No image scenario

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         }

         // this doesn't have a chat template so just use the last message.
-        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+        var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         // based on transformers/processing_paligemma
         let count = input.images.count * config.imageSequenceLength
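
Note: the added "as? String" cast in the Idefics3 and Paligemma processors is needed because the message content is now typed as Any. With a plain-string message the cast succeeds; with the structured content format it fails and the prompt falls back to the empty string. A short illustrative sketch (the example values are made up):

    import MLXLMCommon

    // Plain-string content, e.g. what asMessages() produces for .text prompts:
    let plain: Message = ["role": "user", "content": "What is in this image?"]
    let p1 = plain["content"] as? String ?? ""        // "What is in this image?"

    // Structured content, as in the VLMEval change above: the cast fails, so
    // these processors would currently see an empty prompt.
    let structured: Message = [
        "role": "user",
        "content": [["type": "text", "text": "What is in this image?"], ["type": "image"]],
    ]
    let p2 = structured["content"] as? String ?? ""   // ""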

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 15 additions & 14 deletions
@@ -686,24 +686,25 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }

-    private func prepareMessages(_ messages: [[String: String]], imageTHW: [THW]?) -> [[String: String]] {
+    private func prepareMessages(_ messages: [Message], imageTHW: [THW]?) -> [Message] {
         var messages = messages
+        print(messages)
         // Add system message if not present
-        if messages[0]["role"] != "system" {
+        if let role = messages[0]["role"] as? String, role != "system" {
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-        // Add image markers to last message if needed
-        if let imageTHW {
-            let lastIndex = messages.count - 1
-            var content = messages[lastIndex]["content"] ?? ""
-            let mergeLength = config.mergeSize * config.mergeSize
-            for thw in imageTHW {
-                content += "<|vision_start|>"
-                content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
-                content += "<|vision_end|>"
-            }
-            messages[lastIndex]["content"] = content
-        }
+        // // Add image markers to last message if needed
+        // if let imageTHW {
+        //     let lastIndex = messages.count - 1
+        //     var content = messages[lastIndex]["content"] ?? ""
+        //     let mergeLength = config.mergeSize * config.mergeSize
+        //     for thw in imageTHW {
+        //         content += "<|vision_start|>"
+        //         content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
+        //         content += "<|vision_end|>"
+        //     }
+        //     messages[lastIndex]["content"] = content
+        // }
         return messages
     }
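
Note: the image-marker block is commented out while the new message type is worked out, presumably because the string concatenation no longer type-checks once the content value is Any. One possible adaptation to [Message], sketched here only and not part of this commit, is to cast the content before appending, mirroring the Idefics3/Paligemma change above; it assumes the last message still carries a plain String content.

    // Sketch: expand <|image_pad|> tokens into the last message, assuming its
    // "content" is a plain String. Reuses the mergeSize / THW arithmetic from
    // the original (now commented-out) block.
    if let imageTHW {
        let lastIndex = messages.count - 1
        var content = messages[lastIndex]["content"] as? String ?? ""
        let mergeLength = config.mergeSize * config.mergeSize
        for thw in imageTHW {
            content += "<|vision_start|>"
            content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
            content += "<|vision_end|>"
        }
        messages[lastIndex]["content"] = content
    }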
