Commit 0f68fd2

Use chat template for Qwen 2 VL
1 parent 8cdc4a0 commit 0f68fd2

7 files changed: +91 -65 lines

Applications/VLMEval/ContentView.swift

Lines changed: 10 additions & 1 deletion
@@ -331,7 +331,16 @@ class VLMEvaluator {
     MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

     let result = try await modelContainer.perform { context in
-        var userInput = UserInput(prompt: prompt, images: [.ciImage(image)])
+        var userInput = UserInput(
+            messages: [
+                [
+                    "role": "user",
+                    "content": [
+                        ["type": "text", "text": prompt],
+                        ["type": "image"],
+                    ],
+                ]
+            ], images: [.ciImage(image)])
         userInput.processing.resize = .init(width: 448, height: 448)

         let input = try await context.processor.prepare(input: userInput)
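
For orientation, here is a minimal usage sketch of the two initializers this commit touches, assuming `UserInput(prompt:)` keeps its default empty image list; the two-image message mirrors the shape quoted in the old Qwen2VL comment removed below. The `CIImage` values are stand-ins for real user images.

import CoreImage
import MLXLMCommon

// Plain text prompt: normalized internally to [["role": "user", "content": text]].
let textOnly = UserInput(prompt: "Describe the weather today.")

// Structured content: one text part plus one part per attached image.
let left = CIImage(color: .red).cropped(to: CGRect(x: 0, y: 0, width: 8, height: 8))
let right = CIImage(color: .blue).cropped(to: CGRect(x: 0, y: 0, width: 8, height: 8))
var multiImage = UserInput(
    messages: [
        [
            "role": "user",
            "content": [
                ["type": "text", "text": "What differs between these images?"],
                ["type": "image"],
                ["type": "image"],
            ],
        ]
    ],
    images: [.ciImage(left), .ciImage(right)]
)
// Same optional resize knob used in VLMEval above.
multiImage.processing.resize = .init(width: 448, height: 448)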

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 5 additions & 4 deletions
@@ -4,18 +4,19 @@ import CoreImage
 import Foundation
 import MLX

+public typealias Message = [String: Any]
+
 /// Container for raw user input.
 ///
 /// A ``UserInputProcessor`` can convert this to ``LMInput``.
 /// See also ``ModelContext``.
 public struct UserInput: Sendable {
-
     /// Representation of a prompt or series of messages (conversation).
     public enum Prompt: Sendable, CustomStringConvertible {
         case text(String)
-        case messages([[String: String]])
+        case messages([Message])

-        public func asMessages() -> [[String: String]] {
+        public func asMessages() -> [Message] {
             switch self {
             case .text(let text):
                 return [["role": "user", "content": text]]
@@ -116,7 +117,7 @@ public struct UserInput: Sendable {
         self.images = images
     }

-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(messages: [Message], images: [Image] = [Image]()) {
         self.prompt = .messages(messages)
         self.images = images
     }
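
Since `Message` is now `[String: Any]`, a message's `content` may be a plain `String` or a nested array of typed parts. A small sketch of how `asMessages()` behaves for each prompt case, following the code above:

import MLXLMCommon

// .text is wrapped into a single user message:
let fromText = UserInput.Prompt.text("Hello").asMessages()
print(fromText)  // [["role": "user", "content": "Hello"]]

// .messages passes through as-is, so structured content survives:
let structured: [Message] = [
    [
        "role": "user",
        "content": [["type": "text", "text": "Hello"], ["type": "image"]],
    ]
]
let fromMessages = UserInput.Prompt.messages(structured).asMessages()
print(fromMessages)  // Same array back; content is still an array of typed parts.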

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
     }

     public func prepare(input: UserInput) throws -> LMInput {
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+        let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         if input.images.isEmpty {
             // No image scenario
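
One consequence of the wider `Message` type: if a caller sends structured content (an array of parts, as in VLMEval above), the `as? String` cast fails and this processor falls back to an empty prompt. A minimal illustration:

import MLXLMCommon

let plain: Message = ["role": "user", "content": "Caption this image."]
let structuredMsg: Message = [
    "role": "user",
    "content": [["type": "text", "text": "Caption this image."]],
]

let a = plain["content"] as? String ?? ""          // "Caption this image."
let b = structuredMsg["content"] as? String ?? ""  // "" (parts are not flattened here)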

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
     }

     // this doesn't have a chat template so just use the last message.
-    var prompt = input.prompt.asMessages().last?["content"] ?? ""
+    var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

     // based on transformers/processing_paligemma
     let count = input.images.count * config.imageSequenceLength
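
For context on the `count` arithmetic: in transformers' processing_paligemma (which the comment above cites), the text is prefixed with `imageSequenceLength` image tokens per image, plus a BOS token and a trailing newline. A hedged sketch of that construction; the token strings and length here are assumptions for illustration, not values from this diff:

// Hypothetical values for illustration.
let imageSequenceLength = 256
let imageCount = 1
let count = imageCount * imageSequenceLength

// Rough shape of the PaliGemma prompt built from a text prefix:
let prompt = "Describe this image."
let fullPrompt = String(repeating: "<image>", count: count) + "<bos>" + prompt + "\n"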

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 41 additions & 45 deletions
@@ -686,69 +686,65 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }

-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
+    private func prepareMessages(_ messages: [Message]) -> [Message] {
+        var messages = messages
+        print(messages)
+        // Add system message if not present
+        if let role = messages[0]["role"] as? String, role != "system" {
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }

-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
-            }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
+        return messages
     }

+    // public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
+    //     let messages = prepareMessages(prompt.asMessages())
+    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
+    //     return tokenizer.decode(tokens: tokens)
+    // }
+
     public func prepare(input: UserInput) throws -> LMInput {
+        // Text-only input
         if input.images.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
+            let messages = input.prompt.asMessages()
+            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
             return LMInput(tokens: MLXArray(promptTokens))
         }

-        // image_processing_qwen2_vl.preprocess
+        // Input with images
         let images = try input.images.map {
             try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
         let pixels = concatenated(images.map { $0.0 })
         let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })

-        // processing_qwen2_vl.Qwen2VLProcessor
-        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
+        // Get tokens from messages
+        let messages = prepareMessages(input.prompt.asMessages())
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Replace single image pad token with correct number for each image
+        let imagePadToken = try tokenizer.encode(text: "<|image_pad|>").first!
+        let mergeLength = config.mergeSize * config.mergeSize
+
+        // TODO: This assumes that there is only one image. A better solution is needed for the case when multiple images are included.
+        if let imageGridThw = image.imageGridThw {
+            for thw in imageGridThw {
+                if let padIndex = promptTokens.firstIndex(of: imagePadToken) {
+                    let paddingCount = thw.product / mergeLength
+                    promptTokens.replaceSubrange(
+                        padIndex ... (padIndex),
+                        with: Array(repeating: imagePadToken, count: paddingCount)
+                    )
+                }
+            }
+        }
+
+        // TODO: For debugging. Remove later.
+        let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)
+        print(promptTokensDecoded)
+
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray).asType(.int8)
-
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
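
The core of the new approach is the token-level expansion: the chat template emits a single `<|image_pad|>` placeholder per image, and the processor then widens each placeholder to `thw.product / mergeLength` copies. A self-contained sketch of that expansion over plain `Int` token IDs (the IDs and grid values below are made up for illustration):

// Expand each occurrence of `padToken` in `tokens`, left to right, to the
// matching count from `padCounts` (one entry per image), mirroring the
// replaceSubrange loop in the diff above.
func expandImagePads(tokens: [Int], padToken: Int, padCounts: [Int]) -> [Int] {
    var tokens = tokens
    var searchStart = tokens.startIndex
    for count in padCounts {
        // Search past the previous expansion so each placeholder is matched
        // to its own image (the diff's TODO notes that always searching from
        // the start only handles the single-image case correctly).
        guard let padIndex = tokens[searchStart...].firstIndex(of: padToken) else { break }
        tokens.replaceSubrange(padIndex ... padIndex, with: Array(repeating: padToken, count: count))
        searchStart = padIndex + count
    }
    return tokens
}

// Example: 151655 stands in for <|image_pad|>; a 1x4x4 grid with mergeSize 2
// gives 16 / 4 = 4 pad tokens.
let expanded = expandImagePads(tokens: [1, 151655, 2], padToken: 151655, padCounts: [4])
// expanded == [1, 151655, 151655, 151655, 151655, 2]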

Package.swift

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@ let package = Package(
     ],
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.21.2"),
-        .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
+        // .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
+        .package(
+            url: "https://github.com/DePasqualeOrg/swift-transformers", branch: "images-and-tools"),
         .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),
         .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"),
     ],
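
Note: this pin and the Jinja pin in Package.resolved below both point at personal forks tracked by branch rather than tagged releases, presumably because applying a chat template to structured image messages needs tokenizer and Jinja support not yet available upstream. Branch pins like these are typically temporary until the functionality lands in the upstream packages.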

mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 30 additions & 12 deletions
@@ -1,5 +1,5 @@
 {
-  "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c",
+  "originHash" : "327a4376ec20e25f941929e0bd2eefea67914f3c98414e5489f49c7e49eab7ab",
   "pins" : [
     {
       "identity" : "gzipswift",
@@ -13,10 +13,10 @@
     {
       "identity" : "jinja",
       "kind" : "remoteSourceControl",
-      "location" : "https://github.com/maiqingqiang/Jinja",
+      "location" : "https://github.com/DePasqualeOrg/Jinja",
       "state" : {
-        "revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
-        "version" : "1.0.5"
+        "branch" : "add-functionality",
+        "revision" : "c1a95c9fc3fb489bbdfebe0bd97f5ea529c3c8a9"
       }
     },
     {
@@ -33,8 +33,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/gonzalezreal/NetworkImage",
       "state" : {
-        "revision" : "7aff8d1b31148d32c5933d75557d42f6323ee3d1",
-        "version" : "6.0.0"
+        "revision" : "2849f5323265386e200484b0d0f896e73c3411b9",
+        "version" : "6.0.1"
       }
     },
     {
@@ -55,13 +55,31 @@
         "version" : "1.5.0"
       }
     },
+    {
+      "identity" : "swift-cmark",
+      "kind" : "remoteSourceControl",
+      "location" : "https://github.com/swiftlang/swift-cmark",
+      "state" : {
+        "revision" : "3ccff77b2dc5b96b77db3da0d68d28068593fa53",
+        "version" : "0.5.0"
+      }
+    },
+    {
+      "identity" : "swift-collections",
+      "kind" : "remoteSourceControl",
+      "location" : "https://github.com/apple/swift-collections.git",
+      "state" : {
+        "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7",
+        "version" : "1.1.4"
+      }
+    },
     {
       "identity" : "swift-markdown-ui",
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
       "state" : {
-        "revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc",
-        "version" : "2.4.0"
+        "revision" : "5f613358148239d0292c0cef674a3c2314737f9e",
+        "version" : "2.4.1"
       }
     },
     {
@@ -76,12 +94,12 @@
     {
       "identity" : "swift-transformers",
       "kind" : "remoteSourceControl",
-      "location" : "https://github.com/huggingface/swift-transformers",
+      "location" : "https://github.com/DePasqualeOrg/swift-transformers",
       "state" : {
-        "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
-        "version" : "0.1.13"
+        "branch" : "images-and-tools",
+        "revision" : "706e9a1ce783d68f3a1fba07a888f139148a9bfe"
       }
     }
   ],
-  "version" : 2
+  "version" : 3
 }
