Skip to content

Commit 806c7f2

Browse files
committed
Working
1 parent c1deeb4 commit 806c7f2

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -686,37 +686,22 @@ public class Qwen2VLProcessor: UserInputProcessor {
686686
return (flattenedPatches, .init(gridT, gridH, gridW))
687687
}
688688

689-
private func prepareMessages(_ messages: [Message], imageTHW: [THW]?) -> [Message] {
689+
private func prepareMessages(_ messages: [Message]) -> [Message] {
690690
var messages = messages
691691
print(messages)
692692
// Add system message if not present
693693
if let role = messages[0]["role"] as? String, role != "system" {
694694
messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
695695
}
696696

697-
// // Add image markers to last message if needed
698-
// if let imageTHW {
699-
// let lastIndex = messages.count - 1
700-
// var content = messages[lastIndex]["content"] as? String ?? ""
701-
// let mergeLength = config.mergeSize * config.mergeSize
702-
// for thw in imageTHW {
703-
// content += "<|vision_start|>"
704-
// content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
705-
// content += "<|vision_end|>"
706-
// }
707-
// messages[lastIndex]["content"] = content
708-
// }
709-
710-
// TODO: Instead of the above, replace the single `<|image_pad|>` with repeated padding, using the same logic as above to determine the number of repeats.
711-
712697
return messages
713698
}
714699

715-
public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
716-
let messages = prepareMessages(prompt.asMessages(), imageTHW: imageTHW)
717-
let tokens = try tokenizer.applyChatTemplate(messages: messages)
718-
return tokenizer.decode(tokens: tokens)
719-
}
700+
// public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
701+
// let messages = prepareMessages(prompt.asMessages())
702+
// let tokens = try tokenizer.applyChatTemplate(messages: messages)
703+
// return tokenizer.decode(tokens: tokens)
704+
// }
720705

721706
public func prepare(input: UserInput) throws -> LMInput {
722707
// Text-only input
@@ -725,15 +710,34 @@ public class Qwen2VLProcessor: UserInputProcessor {
725710
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
726711
return LMInput(tokens: MLXArray(promptTokens))
727712
}
713+
728714
// Input with images
729715
let images = try input.images.map {
730716
try preprocess(images: [$0.asCIImage()], processing: input.processing)
731717
}
732718
let pixels = concatenated(images.map { $0.0 })
733719
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
734-
// Prepare messages with image markers
735-
let messages = prepareMessages(input.prompt.asMessages(), imageTHW: image.imageGridThw)
736-
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
720+
721+
// Get tokens from messages
722+
let messages = prepareMessages(input.prompt.asMessages())
723+
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
724+
725+
// Replace single image pad token with correct number for each image
726+
let imagePadToken = try tokenizer.encode(text: "<|image_pad|>").first!
727+
let mergeLength = config.mergeSize * config.mergeSize
728+
729+
// TODO: This assumes that there is only one image. A better solution is needed for the case when multiple images are included.
730+
if let imageGridThw = image.imageGridThw {
731+
for thw in imageGridThw {
732+
if let padIndex = promptTokens.firstIndex(of: imagePadToken) {
733+
let paddingCount = thw.product / mergeLength
734+
promptTokens.replaceSubrange(
735+
padIndex ... (padIndex),
736+
with: Array(repeating: imagePadToken, count: paddingCount)
737+
)
738+
}
739+
}
740+
}
737741

738742
// TODO: For debugging. Remove later.
739743
let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)

0 commit comments

Comments
 (0)