
Commit f1f67bf

Working on integrating videos
1 parent 408e7a8 commit f1f67bf

File tree (5 files changed, +94 −142 lines):

Applications/VLMEval/ContentView.swift
Libraries/MLXLMCommon/LanguageModel.swift
Libraries/MLXLMCommon/UserInput.swift
Libraries/MLXVLM/Models/Qwen2VL.swift
Package.swift


Applications/VLMEval/ContentView.swift

Lines changed: 21 additions & 5 deletions
@@ -383,17 +383,33 @@ class VLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
         let result = try await modelContainer.perform { context in
-            let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
+            let images: [UserInput.Image] =
+                if let image {
+                    [UserInput.Image.ciImage(image)]
+                } else {
+                    []
+                }
+            let videos: [UserInput.Video] =
+                if let videoURL {
+                    [.url(videoURL)]
+                } else {
+                    []
+                }
             var userInput = UserInput(
                 messages: [
                     [
                         "role": "user",
                         "content": [
-                            ["type": "text", "text": prompt],
-                            ["type": "image"],
-                        ],
+                            ["type": "text", "text": prompt]
+                        ]
+                            + images.map { _ in
+                                ["type": "image"]
+                            }
+                            + videos.map { _ in
+                                ["type": "video"]
+                            },
                     ]
-                ], images: [.ciImage(image)], videos: videos)
+                ], images: images, videos: videos)
             userInput.processing.resize = .init(width: 448, height: 448)
 
             let input = try await context.processor.prepare(input: userInput)
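Not part of the commit, but as a reading aid: a minimal sketch of the content payload the updated view builds when exactly one image and one video are attached (the prompt text is made up). The per-attachment entries are appended after the text entry by the two map calls above.

// Hypothetical shape of the resulting user message (one image, one video).
let content: [[String: String]] = [
    ["type": "text", "text": "Describe what happens in this clip."],
    ["type": "image"],  // one entry per element of `images`
    ["type": "video"],  // one entry per element of `videos`
]
let message: [String: Any] = ["role": "user", "content": content]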

Libraries/MLXLMCommon/LanguageModel.swift

Lines changed: 4 additions & 2 deletions
@@ -69,7 +69,9 @@ public struct LMInput {
     /// Representation of prepared input image(s).
     public struct ProcessedImage {
 
+        /// Concatenated pixels from one or more images
         public let pixels: MLXArray
+        /// Time, height, and width of the images
         public let frames: [THW]?
 
         public init(
@@ -85,13 +87,13 @@ public struct LMInput {
     public struct ProcessedVideo {
 
         public let pixels: MLXArray
-        public let videoGridThw: [THW]?
+        public let frames: [THW]?
 
         public init(
             pixels: MLXArray, videoGridThw: [THW]? = nil
         ) {
             self.pixels = pixels
-            self.videoGridThw = videoGridThw
+            self.frames = videoGridThw
         }
     }
 
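A note on the rename: only the stored property becomes frames; the initializer keeps its videoGridThw: label in this commit. A hedged sketch (the helper name and grid values are illustrative only):

// Illustrative only: wraps the unchanged `videoGridThw:` init label while
// exposing the renamed `frames` property.
func makeProcessedVideo(pixels: MLXArray, grid: [THW]) -> LMInput.ProcessedVideo {
    LMInput.ProcessedVideo(pixels: pixels, videoGridThw: grid)
}
// e.g. makeProcessedVideo(pixels: pixels, grid: [THW(4, 28, 28)]).frames?.first?.product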

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 2 additions & 1 deletion
@@ -134,9 +134,10 @@ public struct UserInput: Sendable {
         self.videos = videos
     }
 
-    public init(messages: [Message], images: [Image] = [Image]()) {
+    public init(messages: [Message], images: [Image] = [Image](), videos: [Video] = [Video]()) {
         self.prompt = .messages(messages)
         self.images = images
+        self.videos = videos
     }
 
     public init(prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init()) {
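Not in the diff: a hedged usage sketch of the extended initializer, mirroring the ContentView change above; the file URL is a placeholder.

// Sketch only; the URL is a placeholder, not from the commit.
let input = UserInput(
    messages: [
        [
            "role": "user",
            "content": [
                ["type": "text", "text": "Summarize this clip."],
                ["type": "video"],
            ],
        ]
    ],
    videos: [.url(URL(fileURLWithPath: "/path/to/clip.mp4"))])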

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 66 additions & 133 deletions
@@ -694,147 +694,83 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
-            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+    public func prepare(input: UserInput) async throws -> LMInput {
+        let messages = input.prompt.asMessages()
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+        // Text-only input
+        if input.images.isEmpty, input.videos.isEmpty {
+            return LMInput(tokens: MLXArray(promptTokens))
         }
-
-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
+        // Input with images and/or videos
+        // Image processing
+        let imagePixelsAndFrames = try input.images.map {
+            try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
-
-        for thw in videoTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
+        let processedImage: LMInput.ProcessedImage?
+        if !imagePixelsAndFrames.isEmpty {
+            let imagePixelsConcatenated = concatenated(imagePixelsAndFrames.map { $0.0 })
+            processedImage = LMInput.ProcessedImage(
+                pixels: imagePixelsConcatenated, frames: imagePixelsAndFrames.map { $0.1 })
+            if let imageFrames = processedImage?.frames {
+                // Replace padding for images
+                promptTokens = try replacePlaceholderTokens(
+                    in: promptTokens, frames: imageFrames, paddingToken: "<|image_pad|>")
+            }
+        } else {
+            processedImage = nil
         }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+        // Video processing
+        var videosAsImageSequences = [[CIImage]]()
+        for video in input.videos {
+            if let imageSequence = try? await MediaProcessing.asCIImageSequence(
+                video.asAVAsset(), samplesPerSecond: 2)
+            {
+                videosAsImageSequences.append(imageSequence)
             }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
-    }
-
-    private func prepareMessages(_ messages: [Message]) -> [Message] {
-        var messages = messages
-        // Add system message if not present
-        if let role = messages[0]["role"] as? String, role != "system" {
-            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-        return messages
-    }
-
-    // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
-    //     let messages = prepareMessages(prompt.asMessages())
-    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
-    //     return tokenizer.decode(tokens: tokens)
-    // }
-
-    public func prepare(input: UserInput) throws -> LMInput {
-        // Text-only input
-        if input.images.isEmpty {
-            let messages = input.prompt.asMessages()
-            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
-            return LMInput(tokens: MLXArray(promptTokens))
+        let videoPixelsAndFrames = try videosAsImageSequences.map {
+            try preprocess(images: $0, processing: input.processing)
         }
-        // Input with images
-        let pixelsAndFrames = try input.images.map {
-            try preprocess(images: [$0.asCIImage()], processing: input.processing)
+        let processedVideo: LMInput.ProcessedVideo?
+        if !videoPixelsAndFrames.isEmpty {
+            let videoPixelsConcatenated = concatenated(videoPixelsAndFrames.map { $0.0 })
+            processedVideo = LMInput.ProcessedVideo(
+                pixels: videoPixelsConcatenated, videoGridThw: videoPixelsAndFrames.map { $0.1 })
+            if let videoFrames = processedVideo?.frames {
+                promptTokens = try replacePlaceholderTokens(
+                    in: promptTokens, frames: videoFrames, paddingToken: "<|video_pad|>")
+            }
+        } else {
+            processedVideo = nil
         }
+        //
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+        return LMInput(
+            text: .init(tokens: promptArray, mask: mask), image: processedImage,
+            video: processedVideo)
+    }
 
-        // var videosAsImageSequences = [[CIImage]]()
-        // for video in input.videos {
-        //     if let imageSequence = try? await MediaProcessing.asCIImageSequence(
-        //         video.asAVAsset(), samplesPerSecond: 2)
-        //     {
-        //         videosAsImageSequences.append(imageSequence)
-        //     }
-        // }
-        // let videos = try videosAsImageSequences.map {
-        //     try preprocess(images: $0, processing: input.processing)
-        // }
-
-        // let imagePixels: MLXArray?
-        // let image: LMInput.ProcessedImage?
-        // if !images.isEmpty {
-        //     imagePixels = concatenated(images.map { $0.0 })
-        //     image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
-        // } else {
-        //     imagePixels = nil
-        //     image = nil
-        // }
-
-        // let videoPixels: MLXArray?
-        // let video: LMInput.ProcessedVideo?
-        // if !videos.isEmpty {
-        //     videoPixels = concatenated(videos.map { $0.0 })
-        //     video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
-        // } else {
-        //     videoPixels = nil
-        //     video = nil
-        // }
-
-        // // processing_qwen2_vl.Qwen2VLProcessor
-        // let prompt = prepare(
-        //     prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
-        // let promptTokens = try tokenizer.encode(text: prompt)
-        // let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        // let mask = ones(like: promptArray).asType(.int8)
-
-        // return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
-        let pixelsConcatenated = concatenated(pixelsAndFrames.map { $0.0 })
-        let image = LMInput.ProcessedImage(
-            pixels: pixelsConcatenated, frames: pixelsAndFrames.map { $0.1 })
-        let messages = prepareMessages(input.prompt.asMessages())
-        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
-        // Replace single image pad token with correct number for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        let imagePlaceholderTokens = try tokenizer.encode(
-            text: "<|vision_start|><|image_pad|><|vision_end|>")
-        guard let frames = image.frames else {
-            throw Qwen2VLProcessorError.framesIsNil
-        }
-        let placeholderRanges = promptTokens.ranges(of: imagePlaceholderTokens)
+    func replacePlaceholderTokens(in promptTokens: [Int], frames: [THW], paddingToken: String)
+        throws -> [Int]
+    {
+        // Replace single padding token with correct number for each image
+        let placeholderTokens = try tokenizer.encode(
+            text: "<|vision_start|>\(paddingToken)<|vision_end|>")
+        let placeholderRanges = promptTokens.ranges(of: placeholderTokens)
         guard placeholderRanges.count == frames.count else {
             throw VLMError.processing(
-                "Number of image placeholders does not match number of frames")
+                "Number of placeholder tokens does not match number of frames")
         }
-        let replacementSequences = try frames.map { thw in
-            let paddingCount = thw.product / mergeLength
+        let mergeLength = config.mergeSize * config.mergeSize
+        let replacementSequences = try frames.map { frame in
+            let paddingCount = frame.product / mergeLength
             return try tokenizer.encode(
                 text:
-                    "<|vision_start|>\(Array(repeating: "<|image_pad|>", count: paddingCount).joined())<|vision_end|>"
+                    "<|vision_start|>\(Array(repeating: paddingToken, count: paddingCount).joined())<|vision_end|>"
             )
         }
-        // Build the final array
+        // Build the final array (images)
         var result: [Int] = []
         var currentIndex = promptTokens.startIndex
         for (range, replacement) in zip(placeholderRanges, replacementSequences) {
@@ -848,10 +784,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
         if currentIndex < promptTokens.endIndex {
             result.append(contentsOf: promptTokens[currentIndex...])
         }
-        promptTokens = result
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray).asType(.int8)
-        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
+        return result
     }
 }
 
@@ -934,17 +867,17 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
 
         let dtype = visionModel.patchEmbed.proj.weight.dtype
 
-        let imageGridThw = input.image?.imageGridThw
+        let imageFrames = input.image?.frames
         let imagePixels = input.image?.pixels.asType(dtype)
 
-        let videoGridThw = input.video?.videoGridThw
+        let videoGridThw = input.video?.frames
        let videoPixels = input.video?.pixels.asType(dtype)
 
         let gridThw: [THW]?
         let pixels: MLXArray?
 
         if videoGridThw == nil {
-            gridThw = imageGridThw
+            gridThw = imageFrames
             pixels = imagePixels
         } else {
             gridThw = videoGridThw
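Not part of the commit: a worked example of the padding arithmetic in replacePlaceholderTokens, assuming mergeSize == 2 (so mergeLength == 4); the grid values are made up.

// Assumed values for illustration: 8 temporal patches on a 28 x 28 patch grid.
let frame = THW(8, 28, 28)
let mergeLength = 2 * 2                          // config.mergeSize * config.mergeSize, assuming mergeSize == 2
let paddingCount = frame.product / mergeLength   // 8 * 28 * 28 / 4 = 1568
// The single "<|vision_start|><|video_pad|><|vision_end|>" placeholder emitted by the
// chat template is replaced by 1568 "<|video_pad|>" tokens between the same markers.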

Package.swift

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ let package = Package(
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.21.2")),
         .package(
-            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.16")
+            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.17")
         ),
         .package(
             url: "https://github.com/apple/swift-async-algorithms", .upToNextMinor(from: "1.0.0")),
