@@ -686,53 +686,60 @@ public class Qwen2VLProcessor: UserInputProcessor {
686
686
return ( flattenedPatches, . init( gridT, gridH, gridW) )
687
687
}
688
688
689
/// Renders `prompt` through the tokenizer's chat template and returns the prompt tokens,
/// with image placeholders sized according to `imageTHW`.
///
/// - Parameters:
///   - prompt: Either plain text or a list of role/content messages.
///   - imageTHW: One grid descriptor per attached image, or `nil` for a text-only prompt.
/// - Returns: The token ids produced by the chat template, after each image placeholder
///   has been expanded to `thw.product / (mergeSize * mergeSize)` padding tokens.
/// - Throws: Any error raised by `tokenizer.applyChatTemplate(messages:)`.
public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> [Int] {
    // The tokenizer has a chat template and expects messages like this:
    //
    // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
    //  {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
    //
    // The output of the prompt template is fed into
    // image_processing_qwen2_vl.preprocess where it is further augmented
    // by replacing tokens according to imageTHW. That replacement is done
    // on the token ids at the end of this function.
    let grids = imageTHW ?? []

    // One `{'type': 'image'}` entry per image; a text-only prompt gets none.
    let imageEntries: [[String: Any]] = grids.map { _ in ["type": "image"] }

    func structuredContent(_ text: String) -> [[String: Any]] {
        [["type": "text", "text": text]] + imageEntries
    }

    let messages: [[String: Any]]
    switch prompt {
    case .messages(let original):
        var chat = original
        // `first` instead of `[0]` so an empty message list does not trap.
        if chat.first?["role"] != "system" {
            chat.insert(
                ["role": "system", "content": "You are a helpful assistant."], at: 0)
        }
        var structured = chat as [[String: Any]]
        // Images are attached to the final message, matching where the previous
        // string-based implementation appended its vision tokens.
        if !imageEntries.isEmpty, let last = structured.indices.last {
            let text = structured[last]["content"] as? String ?? ""
            structured[last]["content"] = structuredContent(text)
        }
        messages = structured
    case .text(let text):
        messages = [
            [
                "role": "user",
                "content": structuredContent(text),
            ]
        ]
    }

    let tokens = try tokenizer.applyChatTemplate(messages: messages)

    // NOTE(review): assumes the chat template emits exactly one "<|image_pad|>" token per
    // image entry -- confirm against the model's template. If the tokenizer does not know
    // the token, the template output is returned unchanged.
    guard !grids.isEmpty, let padToken = tokenizer.convertTokenToId("<|image_pad|>") else {
        return tokens
    }

    // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
    let mergeLength = config.mergeSize * config.mergeSize
    var remainingGrids = grids.makeIterator()
    var expanded: [Int] = []
    expanded.reserveCapacity(tokens.count)
    for token in tokens {
        if token == padToken, let thw = remainingGrids.next() {
            expanded.append(
                contentsOf: repeatElement(padToken, count: thw.product / mergeLength))
        } else {
            expanded.append(token)
        }
    }
    return expanded
}
730
738
731
739
public func prepare( input: UserInput ) throws -> LMInput {
732
740
if input. images. isEmpty {
733
741
// just a straight text prompt
734
- let prompt = prepare ( prompt: input. prompt, imageTHW: nil )
735
- let promptTokens = try tokenizer. encode ( text: prompt)
742
+ let promptTokens = try prepare ( prompt: input. prompt, imageTHW: nil )
736
743
return LMInput ( tokens: MLXArray ( promptTokens) )
737
744
}
738
745
@@ -744,8 +751,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
744
751
let image = LMInput . ProcessedImage ( pixels: pixels, imageGridThw: images. map { $0. 1 } )
745
752
746
753
// processing_qwen2_vl.Qwen2VLProcessor
747
- let prompt = prepare ( prompt: input. prompt, imageTHW: image. imageGridThw)
748
- let promptTokens = try tokenizer. encode ( text: prompt)
754
+ let promptTokens = try prepare ( prompt: input. prompt, imageTHW: image. imageGridThw)
749
755
let promptArray = MLXArray ( promptTokens) . expandedDimensions ( axis: 0 )
750
756
let mask = ones ( like: promptArray) . asType ( . int8)
751
757
0 commit comments