
Commit 3864824

Add video support for Qwen2-VL (#187)
* Implement video support for Qwen 2 VL
* Fix formatting and color space conversion
1 parent 7109e3e commit 3864824

File tree

5 files changed: +265 −33 lines

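Taken together, the diffs below thread a video from the SwiftUI picker through `UserInput` and into the model's input processor. As a hedged sketch of the new call path (not code from this commit; `modelContainer` is assumed to come from loading a VLM the way VLMEval already does, and token generation is elided):

```swift
import Foundation
import MLXLMCommon

// Sketch only: prepare a video prompt for a VLM, mirroring the
// VLMEvaluator changes in this commit. Model loading and generation
// are elided; the prompt text is a placeholder.
func countPromptTokens(videoURL: URL, modelContainer: ModelContainer) async throws -> Int {
    try await modelContainer.perform { context in
        // Videos now ride alongside images in UserInput.
        var userInput = UserInput(
            prompt: "Describe what happens in this video.",
            videos: [.url(videoURL)])
        userInput.processing.resize = .init(width: 448, height: 448)
        let input = try await context.processor.prepare(input: userInput)
        return input.text.tokens.size
    }
}
```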

Applications/VLMEval/ContentView.swift

Lines changed: 97 additions & 14 deletions
```diff
@@ -1,5 +1,6 @@
 // Copyright 2024 Apple Inc.
 
+import AVKit
 import CoreImage
 import MLX
 import MLXLMCommon
@@ -19,12 +20,28 @@ struct ContentView: View {
     @State var llm = VLMEvaluator()
     @Environment(DeviceStat.self) private var deviceStat
 
-    @State private var selectedImage: PlatformImage? = nil
+    @State private var selectedImage: PlatformImage? = nil {
+        didSet {
+            if selectedImage != nil {
+                selectedVideoURL = nil
+                player = nil
+            }
+        }
+    }
+    @State private var selectedVideoURL: URL? = nil {
+        didSet {
+            if let selectedVideoURL {
+                player = AVPlayer(url: selectedVideoURL)
+                selectedImage = nil
+            }
+        }
+    }
     @State private var showingImagePicker = false
     @State private var selectedItem: PhotosPickerItem? = nil
+    @State private var player: AVPlayer? = nil
 
     private var currentImageURL: URL? {
-        selectedImage == nil
+        selectedImage == nil && selectedVideoURL == nil
             ? URL(
                 string:
                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
@@ -74,40 +91,60 @@ struct ContentView: View {
                             EmptyView()
                         }
                     }
+                } else if let player {
+                    VideoPlayer(player: player)
+                        .scaledToFit()
+                        .frame(maxHeight: 300)
+                        .cornerRadius(12)
                 }
 
                 HStack {
                     #if os(iOS)
                         PhotosPicker(
                             selection: $selectedItem,
-                            matching: .images
+                            matching: PHPickerFilter.any(of: [
+                                PHPickerFilter.images, PHPickerFilter.videos,
+                            ])
                         ) {
-                            Label("Select Image", systemImage: "photo.badge.plus")
+                            Label("Select Image/Video", systemImage: "photo.badge.plus")
                         }
                         .onChange(of: selectedItem) {
                             Task {
-                                if let data = try? await selectedItem?.loadTransferable(
+                                if let video = try? await selectedItem?.loadTransferable(
+                                    type: TransferableVideo.self)
+                                {
+                                    selectedVideoURL = video.url
+                                } else if let data = try? await selectedItem?.loadTransferable(
                                     type: Data.self)
                                 {
                                     selectedImage = PlatformImage(data: data)
                                 }
                             }
                         }
                     #else
-                        Button("Select Image") {
+                        Button("Select Image/Video") {
                            showingImagePicker = true
                        }
                        .fileImporter(
                            isPresented: $showingImagePicker,
-                            allowedContentTypes: [.image]
+                            allowedContentTypes: [.image, .movie]
                        ) { result in
                            switch result {
                            case .success(let file):
                                Task { @MainActor in
                                    do {
-                                        let data = try loadImage(from: file)
+                                        let data = try loadData(from: file)
                                        if let image = PlatformImage(data: data) {
                                            selectedImage = image
+                                        } else if let fileType = UTType(
+                                            filenameExtension: file.pathExtension),
+                                            fileType.conforms(to: .movie)
+                                        {
+                                            if let sandboxURL = try? loadVideoToSandbox(
+                                                from: file)
+                                            {
+                                                selectedVideoURL = sandboxURL
+                                            }
                                        } else {
                                            print("Failed to create image from data")
                                        }
@@ -214,30 +251,34 @@ struct ContentView: View {
         if let selectedImage = selectedImage {
             #if os(iOS)
                 let ciImage = CIImage(image: selectedImage)
-                await llm.generate(prompt: prompt, image: ciImage ?? CIImage())
+                await llm.generate(prompt: prompt, image: ciImage ?? CIImage(), videoURL: nil)
             #else
                 if let cgImage = selectedImage.cgImage(
                     forProposedRect: nil, context: nil, hints: nil)
                 {
                     let ciImage = CIImage(cgImage: cgImage)
-                    await llm.generate(prompt: prompt, image: ciImage)
+                    await llm.generate(prompt: prompt, image: ciImage, videoURL: nil)
                }
            #endif
        } else if let imageURL = currentImageURL {
            do {
                let (data, _) = try await URLSession.shared.data(from: imageURL)
                if let ciImage = CIImage(data: data) {
-                    await llm.generate(prompt: prompt, image: ciImage)
+                    await llm.generate(prompt: prompt, image: ciImage, videoURL: nil)
                }
            } catch {
                print("Failed to load image: \(error.localizedDescription)")
            }
+        } else {
+            if let videoURL = selectedVideoURL {
+                await llm.generate(prompt: prompt, image: nil, videoURL: videoURL)
+            }
        }
    }
 
    #if os(macOS)
-        private func loadImage(from url: URL) throws -> Data {
+        private func loadData(from url: URL) throws -> Data {
            guard url.startAccessingSecurityScopedResource() else {
                throw NSError(
                    domain: "FileAccess", code: -1,
@@ -246,6 +287,17 @@ struct ContentView: View {
            defer { url.stopAccessingSecurityScopedResource() }
            return try Data(contentsOf: url)
        }
+
+        private func loadVideoToSandbox(from url: URL) throws -> URL {
+            guard url.startAccessingSecurityScopedResource() else {
+                throw NSError(
+                    domain: "FileAccess", code: -1,
+                    userInfo: [NSLocalizedDescriptionKey: "Failed to access the file."])
+            }
+            defer { url.stopAccessingSecurityScopedResource() }
+            let sandboxURL = try SandboxFileTransfer.transferFileToTemp(from: url)
+            return sandboxURL
+        }
    #endif
 
    private func copyToClipboard(_ string: String) {
@@ -318,7 +370,7 @@ class VLMEvaluator {
        }
    }
 
-    func generate(prompt: String, image: CIImage) async {
+    func generate(prompt: String, image: CIImage?, videoURL: URL?) async {
        guard !running else { return }
 
        running = true
@@ -331,7 +383,9 @@ class VLMEvaluator {
        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
        let result = try await modelContainer.perform { context in
-            var userInput = UserInput(prompt: prompt, images: [.ciImage(image)])
+            let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : []
+            let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
+            var userInput = UserInput(prompt: prompt, images: images, videos: videos)
            userInput.processing.resize = .init(width: 448, height: 448)
 
            let input = try await context.processor.prepare(input: userInput)
@@ -370,3 +424,32 @@ class VLMEvaluator {
        running = false
    }
 }
+
+#if os(iOS)
+    struct TransferableVideo: Transferable {
+        let url: URL
+
+        static var transferRepresentation: some TransferRepresentation {
+            FileRepresentation(contentType: .movie) { movie in
+                SentTransferredFile(movie.url)
+            } importing: { received in
+                let sandboxURL = try SandboxFileTransfer.transferFileToTemp(from: received.file)
+                return .init(url: sandboxURL)
+            }
+        }
+    }
+#endif
+
+struct SandboxFileTransfer {
+    static func transferFileToTemp(from sourceURL: URL) throws -> URL {
+        let tempDir = FileManager.default.temporaryDirectory
+        let sandboxURL = tempDir.appendingPathComponent(sourceURL.lastPathComponent)
+
+        if FileManager.default.fileExists(atPath: sandboxURL.path()) {
+            try FileManager.default.removeItem(at: sandboxURL)
+        }
+
+        try FileManager.default.copyItem(at: sourceURL, to: sandboxURL)
+        return sandboxURL
+    }
+}
```
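A note on the two helpers added at the bottom: picked media arrives as a security-scoped or Photos-provided URL that may become inaccessible once access ends, so `SandboxFileTransfer.transferFileToTemp(from:)` stages a copy in the app's temporary directory first. A minimal usage sketch (the source path is a placeholder):

```swift
import Foundation

// Sketch: stage a picked movie in the sandbox before playback/processing.
// `pickedURL` stands in for a URL handed back by PhotosPicker/fileImporter.
let pickedURL = URL(fileURLWithPath: "/path/to/picked.mov")
do {
    let localURL = try SandboxFileTransfer.transferFileToTemp(from: pickedURL)
    print("Staged copy at \(localURL.path)")
} catch {
    print("Staging failed: \(error)")
}
```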

Libraries/MLXLMCommon/LanguageModel.swift

Lines changed: 21 additions & 1 deletion
```diff
@@ -37,6 +37,7 @@ public struct THW: Sendable {
 public struct LMInput {
     public let text: Text
     public let image: ProcessedImage?
+    public let video: ProcessedVideo?
 
     /// Representation of tokenized input text.
     public struct Text {
@@ -79,13 +80,32 @@ public struct LMInput {
         }
     }
 
+    /// Representation of prepared input video(s).
+    /// For now, this is virtually identical to ProcessedImage.
+    public struct ProcessedVideo {
+
+        public let pixels: MLXArray
+        public let videoGridThw: [THW]?
+
+        public init(
+            pixels: MLXArray, videoGridThw: [THW]? = nil
+        ) {
+            self.pixels = pixels
+            self.videoGridThw = videoGridThw
+        }
+    }
+
     public init(tokens: MLXArray, mask: MLXArray? = nil) {
         self.init(text: .init(tokens: tokens, mask: mask))
     }
 
-    public init(text: LMInput.Text, image: LMInput.ProcessedImage? = nil) {
+    public init(
+        text: LMInput.Text, image: LMInput.ProcessedImage? = nil,
+        video: LMInput.ProcessedVideo? = nil
+    ) {
        self.text = text
        self.image = image
+        self.video = video
    }
 }
```
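Downstream model code receives video features already packaged. A sketch of constructing an `LMInput` with the new `ProcessedVideo` (dummy arrays, not real features; `LMInput.Text(tokens:)` is assumed to default its mask to nil, as the diff suggests):

```swift
import MLX
import MLXLMCommon

// Sketch with placeholder data: text tokens plus a dummy video pixel stack.
let text = LMInput.Text(tokens: MLXArray([101, 2023, 2003]))
let pixels = MLXArray(Array(repeating: Float(0), count: 24), [2, 3, 4])  // dummy shape
let video = LMInput.ProcessedVideo(pixels: pixels)  // videoGridThw defaults to nil
let input = LMInput(text: text, video: video)
```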

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 18 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+import AVFoundation
 import CoreImage
 import Foundation
 import MLX
@@ -34,6 +35,20 @@ public struct UserInput: Sendable {
         }
     }
 
+    public enum Video: Sendable {
+        case avAsset(AVAsset)
+        case url(URL)
+
+        public func asAVAsset() -> AVAsset {
+            switch self {
+            case .avAsset(let asset):
+                return asset
+            case .url(let url):
+                return AVAsset(url: url)
+            }
+        }
+    }
+
     /// Representation of a single image.
     public enum Image: Sendable {
         case ciImage(CIImage)
@@ -109,11 +124,13 @@ public struct UserInput: Sendable {
 
     public var prompt: Prompt
     public var images = [Image]()
+    public var videos = [Video]()
     public var processing: Processing = .init()
 
-    public init(prompt: String, images: [Image] = [Image]()) {
+    public init(prompt: String, images: [Image] = [Image](), videos: [Video] = [Video]()) {
         self.prompt = .text(prompt)
         self.images = images
+        self.videos = videos
     }
 
     public init(messages: [[String: String]], images: [Image] = [Image]()) {
```
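The `Video` enum mirrors `Image`: a caller can pass either a URL or an existing `AVAsset`, and `asAVAsset()` normalizes to the latter for frame extraction. A short sketch (the file path is a placeholder):

```swift
import AVFoundation
import Foundation
import MLXLMCommon

// Sketch: the two ways to supply a video to UserInput.
let url = URL(fileURLWithPath: "/path/to/clip.mov")  // placeholder path
let byURL: UserInput.Video = .url(url)
let byAsset: UserInput.Video = .avAsset(AVAsset(url: url))

// Either case resolves to an AVAsset when frames are needed.
let asset = byURL.asAVAsset()

var input = UserInput(prompt: "Summarize the clip.", videos: [byURL])
```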

Libraries/MLXVLM/MediaProcessing.swift

Lines changed: 45 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+import AVFoundation
 import CoreImage.CIFilterBuiltins
 import MLX
 import MLXLMCommon
@@ -154,4 +155,48 @@ public enum MediaProcessing {
 
         return image
     }
+
+    static func asCIImageSequence(_ asset: AVAsset, samplesPerSecond: Int) async throws -> [CIImage]
+    {
+        // Use AVAssetImageGenerator to extract frames
+        let generator = AVAssetImageGenerator(asset: asset)
+        generator.appliesPreferredTrackTransform = true
+        generator.requestedTimeToleranceBefore = .zero
+        generator.requestedTimeToleranceAfter = .zero
+
+        // Calculate the time values we want to sample
+        guard let duration = try? await asset.load(.duration) else {
+            throw NSError(
+                domain: "MediaProcessing", code: -1,
+                userInfo: [NSLocalizedDescriptionKey: "Failed to load the asset's duration"])
+        }
+
+        let durationInSeconds = duration.seconds
+        let samplesPerSecond = Double(samplesPerSecond)
+        let secondsPerSample = 1.0 / samplesPerSecond
+        let totalFramesToSample = durationInSeconds * samplesPerSecond
+        let durationTimeValue = duration.value
+        let sampledTimeValues = MLXArray.linspace(
+            0, durationTimeValue, count: Int(totalFramesToSample)
+        ).asArray(Int64.self)
+
+        // Construct a CMTime using the sampled CMTimeValue's and the asset's timescale
+        let timescale = duration.timescale
+        let sampledTimes = sampledTimeValues.map { CMTime(value: $0, timescale: timescale) }
+
+        // Collect the frames
+        var ciImages: [CIImage] = []
+        for await result in await generator.images(for: sampledTimes) {
+            switch result {
+            case .success(requestedTime: let requested, let image, actualTime: let actual):
+                let ciImage = CIImage(
+                    cgImage: image, options: [.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!])
+                ciImages.append(ciImage)
+            case .failure(requestedTime: let requested, let error):
+                break
+            }
+        }
+
+        return ciImages
+    }
 }
```
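`asCIImageSequence(_:samplesPerSecond:)` samples frames at evenly spaced times across the asset's duration and tags each frame as sRGB (the color space fix from the commit message). A usage sketch (the path is a placeholder; since the method is internal, this would run inside the MLXVLM module):

```swift
import AVFoundation
import CoreImage

// Sketch: sample a clip at 2 frames per second and count the results.
func sampleClip() async throws {
    let asset = AVAsset(url: URL(fileURLWithPath: "/path/to/clip.mov"))  // placeholder
    let frames: [CIImage] = try await MediaProcessing.asCIImageSequence(
        asset, samplesPerSecond: 2)
    print("Sampled \(frames.count) frames")
}
```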
