@@ -367,10 +367,10 @@ private enum Vision {
367
367
}
368
368
369
369
public func callAsFunction(
370
- _ x: MLXArray , gridThw : [ THW ] , rotaryPositionEmbedding: MLXArray
370
+ _ x: MLXArray , frames : [ THW ] , rotaryPositionEmbedding: MLXArray
371
371
) -> MLXArray {
372
372
let sequenceLength = x. dim ( 0 )
373
- let B = gridThw [ 0 ] . t
373
+ let B = frames [ 0 ] . t
374
374
let L = sequenceLength / B
375
375
376
376
let qkv = qkv ( x)
@@ -435,13 +435,13 @@ private enum Vision {
435
435
}
436
436
437
437
func callAsFunction(
438
- _ hiddenStates: MLXArray , gridThw : [ THW ] , rotaryPositionEmbedding: MLXArray
438
+ _ hiddenStates: MLXArray , frames : [ THW ] , rotaryPositionEmbedding: MLXArray
439
439
) -> MLXArray {
440
440
var hiddenStates =
441
441
hiddenStates
442
442
+ attention(
443
443
norm1 ( hiddenStates) ,
444
- gridThw : gridThw ,
444
+ frames : frames ,
445
445
rotaryPositionEmbedding: rotaryPositionEmbedding
446
446
)
447
447
hiddenStates = hiddenStates + mlp( norm2 ( hiddenStates) )
@@ -479,10 +479,10 @@ private enum Vision {
479
479
spatialMergeSize: 2 )
480
480
}
481
481
482
- func rotaryPositionEmbedding( _ gridThw : [ THW ] ) -> MLXArray {
482
+ func rotaryPositionEmbedding( _ frames : [ THW ] ) -> MLXArray {
483
483
var positionIds = [ MLXArray] ( )
484
484
485
- for row in gridThw {
485
+ for row in frames {
486
486
let ( t, h, w) = row. values
487
487
488
488
var hposIds = expandedDimensions ( MLXArray ( 0 ..< h) , axis: 1 )
@@ -516,22 +516,22 @@ private enum Vision {
516
516
}
517
517
518
518
let indices = concatenated ( positionIds, axis: 0 )
519
- let maxGridSize = gridThw . lazy. map { max ( $0. h, $0. w) } . max ( ) ?? 0
520
- let rotaryPositionEmbedFull = rotaryPositionEmbedding ( sequenceLength: maxGridSize ) [
519
+ let maxFrameSize = frames . lazy. map { max ( $0. h, $0. w) } . max ( ) ?? 0
520
+ let rotaryPositionEmbedFull = rotaryPositionEmbedding ( sequenceLength: maxFrameSize ) [
521
521
indices]
522
522
523
523
return rotaryPositionEmbedFull. reshaped ( indices. dim ( 0 ) , - 1 )
524
524
}
525
525
526
- public func callAsFunction( _ hiddenStates: MLXArray , gridThw : [ THW ] ) -> MLXArray {
526
+ public func callAsFunction( _ hiddenStates: MLXArray , frames : [ THW ] ) -> MLXArray {
527
527
var hiddenStates = patchEmbed ( hiddenStates)
528
- let rotaryPositionEmbedding = rotaryPositionEmbedding ( gridThw )
528
+ let rotaryPositionEmbedding = rotaryPositionEmbedding ( frames )
529
529
530
- let batchSize = gridThw . count
530
+ let batchSize = frames . count
531
531
532
532
for block in blocks {
533
533
hiddenStates = block (
534
- hiddenStates, gridThw : gridThw ,
534
+ hiddenStates, frames : frames ,
535
535
rotaryPositionEmbedding: rotaryPositionEmbedding)
536
536
}
537
537
@@ -585,6 +585,10 @@ private enum Vision {
585
585
/// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
586
586
public class Qwen2VLProcessor : UserInputProcessor {
587
587
588
/// Errors thrown by `Qwen2VLProcessor` while preparing input.
enum Qwen2VLProcessorError: Error {
    /// The processed image carried no frames when they were needed to expand
    /// the image placeholder tokens.
    case framesIsNil
}
591
+
588
592
private let config : Qwen2VLProcessorConfiguration
589
593
private let tokenizer : any Tokenizer
590
594
@@ -739,61 +743,116 @@ public class Qwen2VLProcessor: UserInputProcessor {
739
743
+ " \n <|im_start|>assistant \n "
740
744
}
741
745
742
- public func prepare( input: UserInput ) async throws -> LMInput {
743
- if input. images. isEmpty && input. videos. isEmpty {
744
- // just a straight text prompt
745
- let prompt = prepare ( prompt: input. prompt, imageTHW: nil , videoTHW: nil )
746
- let promptTokens = try tokenizer. encode ( text: prompt)
747
- return LMInput ( tokens: MLXArray ( promptTokens) )
746
/// Returns `messages` with a default system message prepended when the
/// conversation does not already start with one.
///
/// - Parameter messages: Chat messages in the order they will be passed to the chat template.
/// - Returns: The messages unchanged when the list is empty, when the first message's
///   `"role"` is not a `String`, or when that role is already `"system"`; otherwise the
///   messages preceded by a `"You are a helpful assistant."` system message.
private func prepareMessages(_ messages: [Message]) -> [Message] {
    // `first` instead of `messages[0]`: subscripting an empty array traps at runtime.
    guard let role = messages.first?["role"] as? String, role != "system" else {
        return messages
    }
    var messages = messages
    messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
    return messages
}
749
754
750
- // image_processing_qwen2_vl.preprocess
751
- let images = try input. images. map {
755
+ // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
756
+ // let messages = prepareMessages(prompt.asMessages())
757
+ // let tokens = try tokenizer.applyChatTemplate(messages: messages)
758
+ // return tokenizer.decode(tokens: tokens)
759
+ // }
760
+
761
/// Converts user input into model-ready tokens and, when images are present, processed pixels.
///
/// Text-only input is run through the tokenizer's chat template and returned as plain tokens.
/// With images, each image is preprocessed on its own, and every
/// `<|vision_start|><|image_pad|><|vision_end|>` sequence in the templated prompt is replaced
/// by one containing `thw.product / (mergeSize * mergeSize)` pad tokens for the matching image.
///
/// - Parameter input: The prompt, images and processing options supplied by the caller.
/// - Returns: An `LMInput` holding the prompt tokens, plus a mask and the processed image
///   when images were supplied.
/// - Throws: `Qwen2VLProcessorError.framesIsNil` if the processed image has no frames,
///   `VLMError.processing` if the number of image placeholders differs from the number of
///   frames, and any error raised by the tokenizer or by image preprocessing.
///
/// - Note: `input.videos` is not read here; video inputs are currently not processed.
public func prepare(input: UserInput) throws -> LMInput {
    // Text-only input: the chat template output is used as-is.
    guard !input.images.isEmpty else {
        let textTokens = try tokenizer.applyChatTemplate(messages: input.prompt.asMessages())
        return LMInput(tokens: MLXArray(textTokens))
    }

    // Input with images: each result pairs the pixel data with its THW frame.
    let processed = try input.images.map { source in
        try preprocess(images: [source.asCIImage()], processing: input.processing)
    }
    let image = LMInput.ProcessedImage(
        pixels: concatenated(processed.map { $0.0 }),
        frames: processed.map { $0.1 })

    let templateTokens = try tokenizer.applyChatTemplate(
        messages: prepareMessages(input.prompt.asMessages()))

    // The template emits a single pad token per image; each one is expanded below.
    let padTokensPerPatchGroup = config.mergeSize * config.mergeSize
    let placeholder = try tokenizer.encode(
        text: "<|vision_start|><|image_pad|><|vision_end|>")
    guard let frames = image.frames else {
        throw Qwen2VLProcessorError.framesIsNil
    }

    let placeholderRanges = templateTokens.ranges(of: placeholder)
    guard placeholderRanges.count == frames.count else {
        throw VLMError.processing(
            "Number of image placeholders does not match number of frames")
    }

    // One replacement token sequence per frame, in the same order as the placeholders.
    let replacements = try frames.map { thw in
        let pads = String(repeating: "<|image_pad|>", count: thw.product / padTokensPerPatchGroup)
        return try tokenizer.encode(text: "<|vision_start|>\(pads)<|vision_end|>")
    }

    // Splice: copy the tokens between placeholders and swap in each replacement.
    var spliced: [Int] = []
    var cursor = templateTokens.startIndex
    for (position, range) in placeholderRanges.enumerated() {
        spliced += templateTokens[cursor ..< range.lowerBound]
        spliced += replacements[position]
        cursor = range.upperBound
    }
    // Tokens after the last placeholder (an empty slice when nothing remains).
    spliced += templateTokens[cursor...]

    let promptArray = MLXArray(spliced).expandedDimensions(axis: 0)
    let mask = ones(like: promptArray).asType(.int8)
    return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
}
796
-
797
856
}
798
857
799
858
// MARK: - Model
@@ -821,18 +880,18 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
821
880
self . _languageModel. wrappedValue = Language . LanguageModel ( config. textConfiguration)
822
881
}
823
882
824
- private func inputEmbeddings( inputIds: MLXArray , pixelValues: MLXArray ? , gridThw : [ THW ] ? )
883
+ private func inputEmbeddings( inputIds: MLXArray , pixelValues: MLXArray ? , frames : [ THW ] ? )
825
884
-> MLXArray
826
885
{
827
- guard let pixelValues, let gridThw else {
886
+ guard let pixelValues, let frames else {
828
887
return languageModel. model. embedTokens ( inputIds [ . newAxis, . ellipsis] )
829
888
}
830
889
831
890
// Get the input embeddings from the language model
832
891
let inputEmbeds = languageModel. model. embedTokens ( inputIds)
833
892
834
893
// Get the output hidden states from the vision model
835
- var hiddenStates = self . visionModel ( pixelValues, gridThw : gridThw )
894
+ var hiddenStates = self . visionModel ( pixelValues, frames : frames )
836
895
837
896
if hiddenStates. ndim == 2 {
838
897
hiddenStates = hiddenStates [ . newAxis, 0 ... , 0 ... ]
@@ -871,6 +930,8 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
871
930
public func prepare( _ input: LMInput , cache: [ any KVCache ] , windowSize: Int ? ) throws
872
931
-> PrepareResult
873
932
{
933
+ let frames = input. image? . frames
934
+
874
935
let dtype = visionModel. patchEmbed. proj. weight. dtype
875
936
876
937
let imageGridThw = input. image? . imageGridThw
@@ -891,7 +952,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
891
952
}
892
953
893
954
let inputEmbeddings = self . inputEmbeddings (
894
- inputIds: input. text. tokens, pixelValues: pixels, gridThw : gridThw )
955
+ inputIds: input. text. tokens, pixelValues: pixels, frames : frames )
895
956
896
957
let result = languageModel ( nil , cache: cache, inputEmbedding: inputEmbeddings)
897
958
0 commit comments