Commit 3ac6296

Use chat template for Qwen 2 VL

1 parent 8cdc4a0 commit 3ac6296

File tree

3 files changed: +73 -56 lines changed

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 40 additions & 43 deletions
@@ -686,69 +686,66 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
+    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
         var messages = prompt.asMessages()
         if messages[0]["role"] != "system" {
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-
+        // For the last message, we need to add image markers to the content
         let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
+        let lastContent = messages[lastIndex]["content"] ?? ""
+        // Build the content string with image markers
         let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+        var content = lastContent
+        if let imageTHW = imageTHW {
+            for thw in imageTHW {
+                content += "<|vision_start|>"
+                content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
+                    .joined()
+                content += "<|vision_end|>"
             }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
+        }
+        // Update the last message with the combined content
+        messages[lastIndex]["content"] = content
+        let tokens = try tokenizer.applyChatTemplate(messages: messages)
+        return tokenizer.decode(tokens: tokens)
     }
 
     public func prepare(input: UserInput) throws -> LMInput {
+        // Text-only input
         if input.images.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
+            let messages = input.prompt.asMessages()
+            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
             return LMInput(tokens: MLXArray(promptTokens))
         }
-
-        // image_processing_qwen2_vl.preprocess
+        // Input with images
         let images = try input.images.map {
             try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
         let pixels = concatenated(images.map { $0.0 })
         let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
-
-        // processing_qwen2_vl.Qwen2VLProcessor
-        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
+        // Create structured messages with image markers
+        var messages = input.prompt.asMessages()
+        if messages[0]["role"] != "system" {
+            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+        }
+        // Structure the last message to include both text and image markers
+        let lastIndex = messages.count - 1
+        let lastContent = messages[lastIndex]["content"] ?? ""
+        // Build the content string with image markers
+        let mergeLength = config.mergeSize * config.mergeSize
+        var content = lastContent
+        for thw in image.imageGridThw ?? [] {
+            content += "<|vision_start|>"
+            content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
+            content += "<|vision_end|>"
+        }
+        // Update the last message with the combined content
+        messages[lastIndex]["content"] = content
+        // Use the chat template to generate the prompt
+        let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray).asType(.int8)
-
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
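The change above replaces the hand-built `<|im_start|>`/`<|im_end|>` framing with the tokenizer's bundled chat template. A minimal sketch of that call pattern (assuming swift-transformers' `Tokenizers` module as used by this repo; the model ID and prompt are illustrative, and the `<|image_pad|>` run would normally be sized from the image grid as in the diff):

```swift
import Tokenizers

// Render a chat prompt via the tokenizer's chat template instead of
// concatenating <|im_start|>/<|im_end|> markers by hand.
let tokenizer = try await AutoTokenizer.from(pretrained: "Qwen/Qwen2-VL-2B-Instruct")
let messages: [[String: String]] = [
    ["role": "system", "content": "You are a helpful assistant."],
    ["role": "user", "content": "Describe this.<|vision_start|><|image_pad|><|vision_end|>"],
]
let tokens = try tokenizer.applyChatTemplate(messages: messages)
print(tokenizer.decode(tokens: tokens))  // includes the templated im_start/im_end framing
```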

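As for the token accounting both methods share: each image contributes gridT × gridH × gridW patches, and the model merges mergeSize × mergeSize patches into one visual token, so `thw.product / mergeLength` pad tokens are inserted per image. A self-contained sketch with hypothetical values (the `THW` struct here is a stand-in for the repo's type, and the numbers are examples, not from this commit):

```swift
// Stand-in for the repository's THW grid type, for illustration only
struct THW {
    let t: Int, h: Int, w: Int
    var product: Int { t * h * w }
}

let mergeSize = 2                          // hypothetical config.mergeSize
let mergeLength = mergeSize * mergeSize    // 2x2 patches merge into one visual token
let thw = THW(t: 1, h: 28, w: 42)          // hypothetical grid: 1 frame, 28x42 patches

// Augment the user's text the same way prepare(prompt:imageTHW:) does
var content = "What is in this image?"
content += "<|vision_start|>"
content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
content += "<|vision_end|>"

print(thw.product / mergeLength)  // 1 * 28 * 42 / 4 = 294 pad tokens
```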
Package.swift

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@ let package = Package(
     ],
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.21.2"),
-        .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
+        // .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
+        .package(
+            url: "https://github.com/DePasqualeOrg/swift-transformers", branch: "images-and-tools"),
         .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),
         .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"),
     ],

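A note on the dependency change above: pointing SwiftPM at a fork's branch rather than a tagged release pins Package.resolved to that branch's current commit until the graph is re-resolved, which is presumably why the Package.resolved change below accompanies it. A minimal manifest sketch of the branch-requirement form (package and target names here are placeholders, not from this repo):

```swift
// swift-tools-version: 5.9
import PackageDescription

let package = Package(
    name: "Example",
    dependencies: [
        // Resolves to the branch's latest commit at resolution time;
        // `swift package update` moves the pin recorded in Package.resolved.
        .package(
            url: "https://github.com/DePasqualeOrg/swift-transformers",
            branch: "images-and-tools")
    ],
    targets: [
        .target(
            name: "Example",
            dependencies: [.product(name: "Transformers", package: "swift-transformers")])
    ]
)
```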
mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 30 additions & 12 deletions
Some generated files are not rendered by default.

0 commit comments