@@ -694,147 +694,83 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
-            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+    public func prepare(input: UserInput) async throws -> LMInput {
+        let messages = input.prompt.asMessages()
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+        // Text-only input
+        if input.images.isEmpty, input.videos.isEmpty {
+            return LMInput(tokens: MLXArray(promptTokens))
         }
-
-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
+        // Input with images and/or videos
+        // Image processing
+        let imagePixelsAndFrames = try input.images.map {
+            try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
-
-        for thw in videoTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
+        let processedImage: LMInput.ProcessedImage?
+        if !imagePixelsAndFrames.isEmpty {
+            let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
+            processedImage = LMInput.ProcessedImage(
+                pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
+            if let imageFrames = processedImage?.frames {
+                // Replace padding for images
+                promptTokens = try replacePlaceholderTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>")
+            }
+        } else {
+            processedImage = nil
         }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+        // Video processing
+        var videosAsImageSequences = [[CIImage]]()
+        for video in input.videos {
+            if let imageSequence = try? await MediaProcessing.asCIImageSequence(
+                video.asAVAsset(), samplesPerSecond: 2)
+            {
+                videosAsImageSequences.append(imageSequence)
             }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
-    }
-
-    private func prepareMessages(_ messages: [Message]) -> [Message] {
-        var messages = messages
-        // Add system message if not present
-        if let role = messages[0]["role"] as? String, role != "system" {
-            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-        return messages
-    }
-
-    // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
-    //     let messages = prepareMessages(prompt.asMessages())
-    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
-    //     return tokenizer.decode(tokens: tokens)
-    // }
-
-    public func prepare(input: UserInput) throws -> LMInput {
-        // Text-only input
-        if input.images.isEmpty {
-            let messages = input.prompt.asMessages()
-            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
-            return LMInput(tokens: MLXArray(promptTokens))
+        let videoPixelsAndFrames = try videosAsImageSequences.map {
+            try preprocess(images: $0, processing: input.processing)
         }
-        // Input with images
-        let pixelsAndFrames = try input.images.map {
-            try preprocess(images: [$0.asCIImage()], processing: input.processing)
+        let processedVideo: LMInput.ProcessedVideo?
+        if !videoPixelsAndFrames.isEmpty {
+            let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
+            processedVideo = LMInput.ProcessedVideo(
+                pixels: videoPixelsConcatenated, videoGridThw: videoPixelsAndFrames.map { $0.1 })
+            if let videoFrames = processedVideo?.frames {
+                promptTokens = try replacePlaceholderTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
+            }
+        } else {
+            processedVideo = nil
         }
+        //
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+        return LMInput(
+            text: .init(tokens: promptArray, mask: mask), image: processedImage,
+            video: processedVideo)
+    }
 
-        // var videosAsImageSequences = [[CIImage]]()
-        // for video in input.videos {
-        //     if let imageSequence = try? await MediaProcessing.asCIImageSequence(
-        //         video.asAVAsset(), samplesPerSecond: 2)
-        //     {
-        //         videosAsImageSequences.append(imageSequence)
-        //     }
-        // }
-        // let videos = try videosAsImageSequences.map {
-        //     try preprocess(images: $0, processing: input.processing)
-        // }
-
-        // let imagePixels: MLXArray?
-        // let image: LMInput.ProcessedImage?
-        // if !images.isEmpty {
-        //     imagePixels = concatenated(images.map { $0.0 })
-        //     image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
-        // } else {
-        //     imagePixels = nil
-        //     image = nil
-        // }
-
-        // let videoPixels: MLXArray?
-        // let video: LMInput.ProcessedVideo?
-        // if !videos.isEmpty {
-        //     videoPixels = concatenated(videos.map { $0.0 })
-        //     video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
-        // } else {
-        //     videoPixels = nil
-        //     video = nil
-        // }
-
-        // // processing_qwen2_vl.Qwen2VLProcessor
-        // let prompt = prepare(
-        //     prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
-        // let promptTokens = try tokenizer.encode(text: prompt)
-        // let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        // let mask = ones(like: promptArray).asType(.int8)
-
-        // return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
-        let pixelsConcatenated = concatenated(pixelsAndFrames.map { $0.0 })
-        let image = LMInput.ProcessedImage(
-            pixels: pixelsConcatenated, frames: pixelsAndFrames.map { $0.1 })
-        let messages = prepareMessages(input.prompt.asMessages())
-        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
-        // Replace single image pad token with correct number for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        let imagePlaceholderTokens = try tokenizer.encode(
-            text: "<|vision_start|><|image_pad|><|vision_end|>")
-        guard let frames = image.frames else {
-            throw Qwen2VLProcessorError.framesIsNil
-        }
-        let placeholderRanges = promptTokens.ranges(of: imagePlaceholderTokens)
+    func replacePlaceholderTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
+        throws -> [Int]
+    {
+        // Replace single padding token with correct number for each image
+        let placeholderTokens = try tokenizer.encode(
+            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
+        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
         guard placeholderRanges.count == frames.count else {
             throw VLMError.processing(
-                "Number of image placeholders does not match number of frames")
+                "Number of placeholder tokens does not match number of frames")
         }
-        let replacementSequences = try frames.map { thw in
-            let paddingCount = thw.product / mergeLength
+        let mergeLength = config.mergeSize * config.mergeSize
+        let replacementSequences = try frames.map { frame in
+            let paddingCount = frame.product / mergeLength
             return try tokenizer.encode(
                 text:
-                    "<|vision_start|>\(Array(repeating: "<|image_pad|>", count: paddingCount).joined())<|vision_end|>"
+                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
             )
         }
-        // Build the final array
+        // Build the final array (images)
         var result: [Int] = []
         var currentIndex = promptTokens.startIndex
         for (range, replacement) in zip(placeholderRanges, replacementSequences) {
@@ -848,10 +784,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
         if currentIndex < promptTokens.endIndex {
             result.append(contentsOf: promptTokens[currentIndex...])
         }
-        promptTokens = result
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray).asType(.int8)
-        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
+        return result
     }
 }
 
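The heart of the new `replacePlaceholderTokens(in:frames:paddingToken:)` is a splice: locate each `<|vision_start|>…<|vision_end|>` placeholder subsequence in the token array (the PR calls `ranges(of:)` on the token array for this), then rebuild the array with the expanded pad run in each placeholder's place. Below is a minimal, self-contained sketch of that technique; the helper names `findRanges` and `splice` are illustrative, not part of this PR.

// Hypothetical helpers, for illustration only.
// Find non-overlapping occurrences of `pattern` in `tokens`.
func findRanges(of pattern: [Int], in tokens: [Int]) -> [Range<Int>] {
    guard !pattern.isEmpty, tokens.count >= pattern.count else { return [] }
    var ranges: [Range<Int>] = []
    var i = 0
    while i <= tokens.count - pattern.count {
        if Array(tokens[i ..< i + pattern.count]) == pattern {
            ranges.append(i ..< i + pattern.count)
            i += pattern.count  // matches do not overlap
        } else {
            i += 1
        }
    }
    return ranges
}

// Rebuild `tokens`, substituting `replacements[k]` for the k-th range.
func splice(_ tokens: [Int], ranges: [Range<Int>], replacements: [[Int]]) -> [Int] {
    var result: [Int] = []
    var currentIndex = tokens.startIndex
    for (range, replacement) in zip(ranges, replacements) {
        result.append(contentsOf: tokens[currentIndex ..< range.lowerBound])
        result.append(contentsOf: replacement)
        currentIndex = range.upperBound
    }
    if currentIndex < tokens.endIndex {
        result.append(contentsOf: tokens[currentIndex...])
    }
    return result
}

// Token 7 stands in for the single <|image_pad|> placeholder token.
let prompt = [1, 2, 7, 3]
let expanded = splice(
    prompt, ranges: findRanges(of: [7], in: prompt), replacements: [[7, 7, 7, 7]])
// expanded == [1, 2, 7, 7, 7, 7, 3]

Because the function is parameterized on `paddingToken`, the same pass handles both `<|image_pad|>` and `<|video_pad|>`, which is what lets the PR delete the separate image-only path.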
@@ -934,17 +867,17 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
 
         let dtype = visionModel.patchEmbed.proj.weight.dtype
 
-        let imageGridThw = input.image?.imageGridThw
+        let imageFrames = input.image?.frames
         let imagePixels = input.image?.pixels.asType(dtype)
 
-        let videoGridThw = input.video?.videoGridThw
+        let videoGridThw = input.video?.frames
         let videoPixels = input.video?.pixels.asType(dtype)
 
         let gridThw: [THW]?
         let pixels: MLXArray?
 
         if videoGridThw == nil {
-            gridThw = imageGridThw
+            gridThw = imageFrames
             pixels = imagePixels
         } else {
             gridThw = videoGridThw
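For context on the pad counts the processor computes: each `THW` frame describes a temporal × height × width grid of vision patches, and `mergeSize * mergeSize` patches collapse into a single token, so `frame.product / mergeLength` pad tokens are emitted per image or video. A small worked sketch follows; the grid numbers are made up, and `ExampleTHW` is a stand-in for the library's `THW` type.

// Illustrative only: mirrors `frame.product / mergeLength` above.
struct ExampleTHW {
    let t: Int, h: Int, w: Int
    var product: Int { t * h * w }
}

let mergeSize = 2                        // assuming a 2x2 spatial merge, as in Qwen2-VL
let mergeLength = mergeSize * mergeSize  // 4 patches -> 1 token

let frame = ExampleTHW(t: 1, h: 16, w: 24)  // one still image, 16x24 patch grid
let paddingCount = frame.product / mergeLength
print(paddingCount)  // 96 <|image_pad|> tokens for this image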