@@ -367,10 +367,10 @@ private enum Vision {
         }

         public func callAsFunction(
-            _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             let sequenceLength = x.dim(0)
-            let B = gridThw[0].t
+            let B = frames[0].t
             let L = sequenceLength / B

             let qkv = qkv(x)
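The rename only changes the parameter label; the element type stays `THW`. For orientation, here is a hypothetical sketch of what `THW` looks like, inferred purely from its call sites in this diff (`frames[0].t`, `row.values`, `thw.product`, `.init(gridT, gridH, gridW)`); the package's actual definition may differ in detail.

```swift
/// Hypothetical sketch of THW, reconstructed from its usages in this diff;
/// not the package's actual source.
public struct THW {
    public let t: Int  // temporal dimension (frame count)
    public let h: Int  // patch-grid height
    public let w: Int  // patch-grid width

    public init(_ t: Int, _ h: Int, _ w: Int) {
        self.t = t
        self.h = h
        self.w = w
    }

    /// Tuple form, destructured as `let (t, h, w) = row.values`.
    public var values: (Int, Int, Int) { (t, h, w) }

    /// Total patch count, used to size the `<|image_pad|>` run.
    public var product: Int { t * h * w }
}
```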
@@ -435,13 +435,13 @@ private enum Vision {
         }

         func callAsFunction(
-            _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ hiddenStates: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             var hiddenStates =
                 hiddenStates
                 + attention(
                     norm1(hiddenStates),
-                    gridThw: gridThw,
+                    frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding
                 )
             hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
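The block body here is the standard pre-norm transformer pattern: a residual connection around attention on `norm1`, then a residual around the MLP on `norm2`. A generic miniature of that shape, with plain functions standing in for the MLX modules (illustration only, not the package's code):

```swift
// Pre-norm residual block in miniature; the closures stand in for the
// real norm1/attention/norm2/mlp modules.
func preNormBlock(
    _ x: [Float],
    norm1: ([Float]) -> [Float],
    attention: ([Float]) -> [Float],
    norm2: ([Float]) -> [Float],
    mlp: ([Float]) -> [Float]
) -> [Float] {
    var h = zip(x, attention(norm1(x))).map(+)  // x + attention(norm1(x))
    h = zip(h, mlp(norm2(h))).map(+)            // h + mlp(norm2(h))
    return h
}
```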
@@ -479,10 +479,10 @@ private enum Vision {
                 spatialMergeSize: 2)
         }

-        func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray {
+        func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray {
             var positionIds = [MLXArray]()

-            for row in gridThw {
+            for row in frames {
                 let (t, h, w) = row.values

                 var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
@@ -516,22 +516,22 @@ private enum Vision {
             }

             let indices = concatenated(positionIds, axis: 0)
-            let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0
-            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
+            let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0
+            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[
                 indices]

             return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
         }

-        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
+        public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray {
             var hiddenStates = patchEmbed(hiddenStates)
-            let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+            let rotaryPositionEmbedding = rotaryPositionEmbedding(frames)

-            let batchSize = gridThw.count
+            let batchSize = frames.count

             for block in blocks {
                 hiddenStates = block(
-                    hiddenStates, gridThw: gridThw,
+                    hiddenStates, frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding)
             }

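To see why `maxFrameSize` replaces `maxGridSize` without changing behavior: the embedding table is still built once for the largest spatial side across all frames, then gathered row-by-row with `indices`. A toy version of that build-once-then-gather pattern using plain Swift arrays, with made-up frame sizes and a fake two-dimensional "embedding" (not MLX code):

```swift
// Toy illustration of the maxFrameSize + gather step above; the frame
// sizes and table contents are invented for the example.
struct Frame { let t, h, w: Int }

let frames = [Frame(t: 1, h: 16, w: 24), Frame(t: 1, h: 12, w: 12)]

// Largest spatial side across all frames (24 here); one table covers all.
let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0

// Stand-in for rotaryPositionEmbedding(sequenceLength:): one row per position.
let fullTable: [[Float]] = (0 ..< maxFrameSize).map { pos in
    [Float(pos), Float(pos) / 2]
}

// positionIds flattens per-patch (h, w) indices; gathering rows is what
// `rotaryPositionEmbedding(sequenceLength: maxFrameSize)[indices]` does.
let indices = [3, 0, 7]
let gathered = indices.map { fullTable[$0] }
print(gathered)  // [[3.0, 1.5], [0.0, 0.0], [7.0, 3.5]]
```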
@@ -585,6 +585,10 @@ private enum Vision {
 /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
 public class Qwen2VLProcessor: UserInputProcessor {

+    enum Qwen2VLProcessorError: Error {
+        case framesIsNil
+    }
+
     private let config: Qwen2VLProcessorConfiguration
     private let tokenizer: any Tokenizer

@@ -686,72 +690,87 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }

-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        //  {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
+    private func prepareMessages(_ messages: [Message]) -> [Message] {
+        var messages = messages
+        print(messages)
+        // Add system message if not present
+        if let role = messages[0]["role"] as? String, role != "system" {
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
-            }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
+        return messages
     }

+    // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
+    //     let messages = prepareMessages(prompt.asMessages())
+    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
+    //     return tokenizer.decode(tokens: tokens)
+    // }
+
     public func prepare(input: UserInput) throws -> LMInput {
+        // Text-only input
         if input.images.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
+            let messages = input.prompt.asMessages()
+            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
             return LMInput(tokens: MLXArray(promptTokens))
         }

-        // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map {
+        // Input with images
+        let pixelsAndFrames = try input.images.map {
             try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
-        let pixels = concatenated(images.map { $0.0 })
-        let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
+        let pixelsConcatenated = concatenated(pixelsAndFrames.map { $0.0 })
+
+        // Are the images concatenated here because they're frames of a video? How should we handle the case where multiple images are included in a multi-turn chat?
+        let image = LMInput.ProcessedImage(
+            pixels: pixelsConcatenated, frames: pixelsAndFrames.map { $0.1 })
+
+        // Get tokens from messages
+        let messages = prepareMessages(input.prompt.asMessages())
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Replace single image pad token with correct number for each image
+        let mergeLength = config.mergeSize * config.mergeSize
+
+        let imagePlaceholderTokens = try tokenizer.encode(
+            text: "<|vision_start|><|image_pad|><|vision_end|>")
+
+        guard let frames = image.frames else {
+            throw Qwen2VLProcessorError.framesIsNil
+        }
+        for thw in frames {
+            if let padIndex = findSubsequence(promptTokens, imagePlaceholderTokens) {
+                let paddingCount = thw.product / mergeLength
+                promptTokens.replaceSubrange(
+                    padIndex ..< (padIndex + imagePlaceholderTokens.count),
+                    with: try tokenizer.encode(
+                        text:
+                            "<|vision_start|>\(Array(repeating: "<|image_pad|>", count: paddingCount).joined())<|vision_end|>"
+                    )
+                )
+            }
+        }
+
+        let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)
+
+        print(promptTokensDecoded)

-        // processing_qwen2_vl.Qwen2VLProcessor
-        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray).asType(.int8)
-
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }

+    private func findSubsequence(_ array: [Int], _ subsequence: [Int]) -> Int? {
+        guard subsequence.count <= array.count else { return nil }
+
+        for i in 0 ... (array.count - subsequence.count) {
+            if Array(array[i ..< (i + subsequence.count)]) == subsequence {
+                return i
+            }
+        }
+        return nil
+    }
+
 }

 // MARK: - Model
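The placeholder expansion in `prepare(input:)` is easy to check in isolation: `findSubsequence` locates the encoded `<|vision_start|><|image_pad|><|vision_end|>` run, and each image contributes `thw.product / mergeLength` pad tokens; for example, a 1×16×16 patch grid with `mergeSize` 2 gives 256 / 4 = 64. A standalone sketch with invented token ids (the real ids come from the tokenizer's vocabulary):

```swift
// Standalone check of the placeholder-expansion logic, using invented
// token ids in place of the tokenizer's real vocabulary.
func findSubsequence(_ array: [Int], _ subsequence: [Int]) -> Int? {
    guard subsequence.count <= array.count else { return nil }
    for i in 0 ... (array.count - subsequence.count) {
        if Array(array[i ..< (i + subsequence.count)]) == subsequence {
            return i
        }
    }
    return nil
}

let visionStart = 100, imagePad = 101, visionEnd = 102  // fake ids
var promptTokens = [1, 2, visionStart, imagePad, visionEnd, 3]
let placeholder = [visionStart, imagePad, visionEnd]

// THW(1, 16, 16) with mergeSize 2: 256 / 4 = 64 pad tokens.
let paddingCount = (1 * 16 * 16) / (2 * 2)

if let padIndex = findSubsequence(promptTokens, placeholder) {
    let expanded = [visionStart] + Array(repeating: imagePad, count: paddingCount) + [visionEnd]
    promptTokens.replaceSubrange(padIndex ..< (padIndex + placeholder.count), with: expanded)
}
print(promptTokens.count)  // 69: [1, 2] + 66 expanded tokens + [3]
```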
@@ -779,18 +798,18 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
         self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
     }

-    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
+    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?)
         -> MLXArray
     {
-        guard let pixelValues, let gridThw else {
+        guard let pixelValues, let frames else {
             return languageModel(inputIds).logits
         }

         // Get the input embeddings from the language model
         let inputEmbeds = languageModel.model.embedTokens(inputIds)

         // Get the output hidden states from the vision model
-        var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw)
+        var hiddenStates = self.visionModel(pixelValues, frames: frames)

         if hiddenStates.ndim == 2 {
             hiddenStates = hiddenStates[.newAxis, 0..., 0...]
@@ -820,13 +839,13 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
         -> PrepareResult
     {
-        let gridThw = input.image?.imageGridThw
+        let frames = input.image?.frames

         let dtype = visionModel.patchEmbed.proj.weight.dtype
         let pixels = input.image?.pixels.asType(dtype)

         let inputEmbeddings = self.inputEmbeddings(
-            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
+            inputIds: input.text.tokens, pixelValues: pixels, frames: frames)

         let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)
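Taken together, the rename threads one name through the whole pipeline: the processor emits `ProcessedImage(pixels:frames:)`, `Qwen2VL.prepare` reads `input.image?.frames`, and the vision tower consumes the same `[THW]`. A schematic sketch of that flow using only signatures that appear in this diff (surrounding setup is assumed; not runnable on its own):

```swift
// Schematic only: traces where `frames` travels after this change.
func run(processor: Qwen2VLProcessor, model: Qwen2VL, input: UserInput) throws {
    // 1. Processor: preprocess() returns (pixels, THW) per image, and
    //    prepare(input:) stores the THW list as ProcessedImage.frames.
    let lmInput = try processor.prepare(input: input)

    // 2. Model: prepare(_:cache:windowSize:) reads input.image?.frames and
    //    forwards it through inputEmbeddings(inputIds:pixelValues:frames:)
    //    into the vision model.
    print(lmInput.image?.frames ?? [])
}
```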