@@ -686,37 +686,22 @@ public class Qwen2VLProcessor: UserInputProcessor {
686
686
return ( flattenedPatches, . init( gridT, gridH, gridW) )
687
687
}
688
688
689
/// Returns `messages` with a default system message prepended when the conversation does
/// not already start with one.
///
/// The default system message is inserted only when the first message carries a string
/// `"role"` whose value is not `"system"`. An empty array, or a first message without a
/// string role, is returned unchanged.
///
/// - Parameter messages: The chat messages, in conversation order.
/// - Returns: The messages, preceded by the default system message when one was added.
private func prepareMessages(_ messages: [Message]) -> [Message] {
    // Use `first` rather than `[0]`: subscripting an empty array traps at runtime.
    guard let firstRole = messages.first?["role"] as? String, firstRole != "system" else {
        return messages
    }

    var prepared = messages
    prepared.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
    return prepared
}
714
699
715
- public func prepare( prompt: UserInput . Prompt , imageTHW: [ THW ] ? ) throws -> String {
716
- let messages = prepareMessages ( prompt. asMessages ( ) , imageTHW : imageTHW )
717
- let tokens = try tokenizer. applyChatTemplate ( messages: messages)
718
- return tokenizer. decode ( tokens: tokens)
719
- }
700
+ // public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
701
+ // let messages = prepareMessages(prompt.asMessages())
702
+ // let tokens = try tokenizer.applyChatTemplate(messages: messages)
703
+ // return tokenizer.decode(tokens: tokens)
704
+ // }
720
705
721
706
public func prepare( input: UserInput ) throws -> LMInput {
722
707
// Text-only input
@@ -725,15 +710,34 @@ public class Qwen2VLProcessor: UserInputProcessor {
725
710
let promptTokens = try tokenizer. applyChatTemplate ( messages: messages)
726
711
return LMInput ( tokens: MLXArray ( promptTokens) )
727
712
}
713
+
728
714
// Input with images
729
715
let images = try input. images. map {
730
716
try preprocess ( images: [ $0. asCIImage ( ) ] , processing: input. processing)
731
717
}
732
718
let pixels = concatenated ( images. map { $0. 0 } )
733
719
let image = LMInput . ProcessedImage ( pixels: pixels, imageGridThw: images. map { $0. 1 } )
734
- // Prepare messages with image markers
735
- let messages = prepareMessages ( input. prompt. asMessages ( ) , imageTHW: image. imageGridThw)
736
- let promptTokens = try tokenizer. applyChatTemplate ( messages: messages)
720
+
721
+ // Get tokens from messages
722
+ let messages = prepareMessages ( input. prompt. asMessages ( ) )
723
+ var promptTokens = try tokenizer. applyChatTemplate ( messages: messages)
724
+
725
+ // Replace single image pad token with correct number for each image
726
+ let imagePadToken = try tokenizer. encode ( text: " <|image_pad|> " ) . first!
727
+ let mergeLength = config. mergeSize * config. mergeSize
728
+
729
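+ // TODO: This only works for a single image. `firstIndex(of:)` always returns the first pad token, so for a second image it matches the start of the padding already expanded for the first image instead of the second image's own placeholder. Supporting multiple images needs a search that resumes after the previously expanded run.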
730
+ if let imageGridThw = image. imageGridThw {
731
+ for thw in imageGridThw {
732
+ if let padIndex = promptTokens. firstIndex ( of: imagePadToken) {
733
+ let paddingCount = thw. product / mergeLength
734
+ promptTokens. replaceSubrange (
735
+ padIndex ... ( padIndex) ,
736
+ with: Array ( repeating: imagePadToken, count: paddingCount)
737
+ )
738
+ }
739
+ }
740
+ }
737
741
738
742
// TODO: For debugging. Remove later.
739
743
let promptTokensDecoded = try tokenizer. decode ( tokens: promptTokens)
0 commit comments