
Commit 408e7a8

Use chat template for Qwen 2 VL
1 parent 0065df8 commit 408e7a8

8 files changed, +142 -71 lines changed

Applications/VLMEval/ContentView.swift

Lines changed: 10 additions & 2 deletions
@@ -383,9 +383,17 @@ class VLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : []
             let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
-            var userInput = UserInput(prompt: prompt, images: images, videos: videos)
+            var userInput = UserInput(
+                messages: [
+                    [
+                        "role": "user",
+                        "content": [
+                            ["type": "text", "text": prompt],
+                            ["type": "image"],
+                        ],
+                    ]
+                ], images: [.ciImage(image)], videos: videos)
             userInput.processing.resize = .init(width: 448, height: 448)

             let input = try await context.processor.prepare(input: userInput)
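
Note: the call site now passes a structured chat message instead of a plain prompt string. A minimal sketch of the new shape (assuming a prompt: String and a non-nil image: CIImage in scope; Message is the typealias added in Libraries/MLXLMCommon/UserInput.swift):

    // Sketch: the structured message form now handed to the VLM processor.
    let message: Message = [
        "role": "user",
        "content": [
            ["type": "text", "text": prompt],
            ["type": "image"],
        ],
    ]
    let userInput = UserInput(messages: [message], images: [.ciImage(image)])

The ["type": "image"] fragment marks where the chat template emits the <|vision_start|><|image_pad|><|vision_end|> placeholder that Qwen2VLProcessor later expands (see Qwen2VL.swift below).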

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
         // but that is not public so just fall back to text
         let prompt = input.prompt
             .asMessages()
-            .compactMap { $0["content"] }
+            .compactMap { $0["content"] as? String }
             .joined(separator: ". ")
         let promptTokens = tokenizer.encode(text: prompt)
         return LMInput(tokens: MLXArray(promptTokens))
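
Note: with Message == [String: Any], $0["content"] is Any?, so the added "as? String" cast is what keeps this text-only fallback working; messages whose content is a fragment array are silently skipped. A toy sketch (assuming the Message typealias from UserInput.swift):

    // Sketch: compactMap now drops any message whose content is not a plain String.
    let messages: [Message] = [
        ["role": "user", "content": "describe the image"],
        ["role": "user", "content": [["type": "image"]]],  // structured content is dropped
    ]
    let prompt = messages
        .compactMap { $0["content"] as? String }
        .joined(separator: ". ")
    // prompt == "describe the image"

The same cast appears in the Idefics3 and PaliGemma processors below, which read only the last message's content.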

Libraries/MLXLMCommon/LanguageModel.swift

Lines changed: 3 additions & 3 deletions
@@ -70,13 +70,13 @@ public struct LMInput {
     public struct ProcessedImage {

         public let pixels: MLXArray
-        public let imageGridThw: [THW]?
+        public let frames: [THW]?

         public init(
-            pixels: MLXArray, imageGridThw: [THW]? = nil
+            pixels: MLXArray, frames: [THW]? = nil
         ) {
             self.pixels = pixels
-            self.imageGridThw = imageGridThw
+            self.frames = frames
         }
     }

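Note: frames is a rename of imageGridThw. Each THW value is the temporal/height/width patch grid of one image or video clip, and the neutral name covers both. The grid also determines how many <|image_pad|> tokens the processor must emit, as the Qwen2VL.swift hunk below computes. A back-of-the-envelope sketch with plain integers (a mergeSize of 2 is assumed, matching the config.mergeSize * config.mergeSize expression used below):

    // Sketch: patch grid -> number of image pad tokens (plain Ints standing in for THW).
    let (t, h, w) = (1, 16, 16)                    // one image, 16x16 patches
    let mergeLength = 2 * 2                        // mergeSize squared
    let paddingCount = (t * h * w) / mergeLength   // 256 / 4 = 64 pad tokens
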
Libraries/MLXLMCommon/UserInput.swift

Lines changed: 5 additions & 4 deletions
@@ -5,18 +5,19 @@ import CoreImage
 import Foundation
 import MLX

+public typealias Message = [String: Any]
+
 /// Container for raw user input.
 ///
 /// A ``UserInputProcessor`` can convert this to ``LMInput``.
 /// See also ``ModelContext``.
 public struct UserInput: Sendable {
-
     /// Representation of a prompt or series of messages (conversation).
     public enum Prompt: Sendable, CustomStringConvertible {
         case text(String)
-        case messages([[String: String]])
+        case messages([Message])

-        public func asMessages() -> [[String: String]] {
+        public func asMessages() -> [Message] {
             switch self {
             case .text(let text):
                 return [["role": "user", "content": text]]
@@ -133,7 +134,7 @@ public struct UserInput: Sendable {
         self.videos = videos
     }

-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(messages: [Message], images: [Image] = [Image]()) {
         self.prompt = .messages(messages)
         self.images = images
     }
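
Note: loosening Message to [String: Any] is what allows a message's content to be either a plain string or an array of typed fragments. Plain-text prompts still normalize to the same shape, per the asMessages() implementation in this hunk:

    // Sketch: a .text prompt normalizes to the single-message chat form.
    let prompt = UserInput.Prompt.text("hello")
    let messages = prompt.asMessages()
    // [["role": "user", "content": "hello"]], i.e. a [Message] with String content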

Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
     }

     public func prepare(input: UserInput) throws -> LMInput {
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+        let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         if input.images.isEmpty {
             // No image scenario

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         }

         // this doesn't have a chat template so just use the last message.
-        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+        var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         // based on transformers/processing_paligemma
         let count = input.images.count * config.imageSequenceLength

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 120 additions & 59 deletions
@@ -367,10 +367,10 @@ private enum Vision {
         }

         public func callAsFunction(
-            _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             let sequenceLength = x.dim(0)
-            let B = gridThw[0].t
+            let B = frames[0].t
             let L = sequenceLength / B

             let qkv = qkv(x)
@@ -435,13 +435,13 @@ private enum Vision {
         }

         func callAsFunction(
-            _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ hiddenStates: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             var hiddenStates =
                 hiddenStates
                 + attention(
                     norm1(hiddenStates),
-                    gridThw: gridThw,
+                    frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding
                 )
             hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
@@ -479,10 +479,10 @@ private enum Vision {
                 spatialMergeSize: 2)
         }

-        func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray {
+        func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray {
            var positionIds = [MLXArray]()

-            for row in gridThw {
+            for row in frames {
                let (t, h, w) = row.values

                var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
@@ -516,22 +516,22 @@ private enum Vision {
             }

             let indices = concatenated(positionIds, axis: 0)
-            let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0
-            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
+            let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0
+            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[
                 indices]

             return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
         }

-        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
+        public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray {
             var hiddenStates = patchEmbed(hiddenStates)
-            let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+            let rotaryPositionEmbedding = rotaryPositionEmbedding(frames)

-            let batchSize = gridThw.count
+            let batchSize = frames.count

             for block in blocks {
                 hiddenStates = block(
-                    hiddenStates, gridThw: gridThw,
+                    hiddenStates, frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding)
             }

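Note: the Vision changes above are a mechanical rename (gridThw to frames, maxGridSize to maxFrameSize); behavior is unchanged. The rotary table is still sized by the largest spatial side across all frames, sketched here with labeled tuples standing in for THW:

    // Sketch: the largest spatial side across frames sizes the rotary embedding table.
    let frames = [(t: 1, h: 16, w: 12), (t: 1, h: 8, w: 20)]
    let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0   // 20
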
@@ -585,6 +585,10 @@ private enum Vision {
 /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
 public class Qwen2VLProcessor: UserInputProcessor {

+    enum Qwen2VLProcessorError: Error {
+        case framesIsNil
+    }
+
     private let config: Qwen2VLProcessorConfiguration
     private let tokenizer: any Tokenizer

@@ -739,61 +743,116 @@ public class Qwen2VLProcessor: UserInputProcessor {
             + "\n<|im_start|>assistant\n"
     }

-    public func prepare(input: UserInput) async throws -> LMInput {
-        if input.images.isEmpty && input.videos.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
-            return LMInput(tokens: MLXArray(promptTokens))
+    private func prepareMessages(_ messages: [Message]) -> [Message] {
+        var messages = messages
+        // Add system message if not present
+        if let role = messages[0]["role"] as? String, role != "system" {
+            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
+        return messages
+    }

-        // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map {
+    // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
+    //     let messages = prepareMessages(prompt.asMessages())
+    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
+    //     return tokenizer.decode(tokens: tokens)
+    // }
+
+    public func prepare(input: UserInput) throws -> LMInput {
+        // Text-only input
+        if input.images.isEmpty {
+            let messages = input.prompt.asMessages()
+            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+            return LMInput(tokens: MLXArray(promptTokens))
+        }
+        // Input with images
+        let pixelsAndFrames = try input.images.map {
             try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }

-        var videosAsImageSequences = [[CIImage]]()
-        for video in input.videos {
-            if let imageSequence = try? await MediaProcessing.asCIImageSequence(
-                video.asAVAsset(), samplesPerSecond: 2)
-            {
-                videosAsImageSequences.append(imageSequence)
-            }
+        // var videosAsImageSequences = [[CIImage]]()
+        // for video in input.videos {
+        //     if let imageSequence = try? await MediaProcessing.asCIImageSequence(
+        //         video.asAVAsset(), samplesPerSecond: 2)
+        //     {
+        //         videosAsImageSequences.append(imageSequence)
+        //     }
+        // }
+        // let videos = try videosAsImageSequences.map {
+        //     try preprocess(images: $0, processing: input.processing)
+        // }
+
+        // let imagePixels: MLXArray?
+        // let image: LMInput.ProcessedImage?
+        // if !images.isEmpty {
+        //     imagePixels = concatenated(images.map { $0.0 })
+        //     image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
+        // } else {
+        //     imagePixels = nil
+        //     image = nil
+        // }
+
+        // let videoPixels: MLXArray?
+        // let video: LMInput.ProcessedVideo?
+        // if !videos.isEmpty {
+        //     videoPixels = concatenated(videos.map { $0.0 })
+        //     video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
+        // } else {
+        //     videoPixels = nil
+        //     video = nil
+        // }
+
+        // // processing_qwen2_vl.Qwen2VLProcessor
+        // let prompt = prepare(
+        //     prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
+        // let promptTokens = try tokenizer.encode(text: prompt)
+        // let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        // let mask = ones(like: promptArray).asType(.int8)
+
+        // return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
+        let pixelsConcatenated = concatenated(pixelsAndFrames.map { $0.0 })
+        let image = LMInput.ProcessedImage(
+            pixels: pixelsConcatenated, frames: pixelsAndFrames.map { $0.1 })
+        let messages = prepareMessages(input.prompt.asMessages())
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+        // Replace single image pad token with correct number for each image
+        let mergeLength = config.mergeSize * config.mergeSize
+        let imagePlaceholderTokens = try tokenizer.encode(
+            text: "<|vision_start|><|image_pad|><|vision_end|>")
+        guard let frames = image.frames else {
+            throw Qwen2VLProcessorError.framesIsNil
         }
-        let videos = try videosAsImageSequences.map {
-            try preprocess(images: $0, processing: input.processing)
+        let placeholderRanges = promptTokens.ranges(of: imagePlaceholderTokens)
+        guard placeholderRanges.count == frames.count else {
+            throw VLMError.processing(
+                "Number of image placeholders does not match number of frames")
         }
-
-        let imagePixels: MLXArray?
-        let image: LMInput.ProcessedImage?
-        if !images.isEmpty {
-            imagePixels = concatenated(images.map { $0.0 })
-            image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
-        } else {
-            imagePixels = nil
-            image = nil
+        let replacementSequences = try frames.map { thw in
+            let paddingCount = thw.product / mergeLength
+            return try tokenizer.encode(
+                text:
+                    "<|vision_start|>\(Array(repeating: "<|image_pad|>", count: paddingCount).joined())<|vision_end|>"
+            )
         }
-
-        let videoPixels: MLXArray?
-        let video: LMInput.ProcessedVideo?
-        if !videos.isEmpty {
-            videoPixels = concatenated(videos.map { $0.0 })
-            video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
-        } else {
-            videoPixels = nil
-            video = nil
+        // Build the final array
+        var result: [Int] = []
+        var currentIndex = promptTokens.startIndex
+        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
+            // Add tokens before the placeholder
+            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
+            // Add replacement sequence
+            result.append(contentsOf: replacement)
+            currentIndex = range.upperBound
         }
-
-        // processing_qwen2_vl.Qwen2VLProcessor
-        let prompt = prepare(
-            prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
+        // Add any remaining tokens after the last replacement
+        if currentIndex < promptTokens.endIndex {
+            result.append(contentsOf: promptTokens[currentIndex...])
+        }
+        promptTokens = result
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray).asType(.int8)
-
-        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
+        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
-
 }

 // MARK: - Model
@@ -821,18 +880,18 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
         self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
     }

-    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
+    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?)
         -> MLXArray
     {
-        guard let pixelValues, let gridThw else {
+        guard let pixelValues, let frames else {
             return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis])
         }

         // Get the input embeddings from the language model
         let inputEmbeds = languageModel.model.embedTokens(inputIds)

         // Get the ouptut hidden states from the vision model
-        var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw)
+        var hiddenStates = self.visionModel(pixelValues, frames: frames)

         if hiddenStates.ndim == 2 {
             hiddenStates = hiddenStates[.newAxis, 0..., 0...]
@@ -871,6 +930,8 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
         -> PrepareResult
     {
+        let frames = input.image?.frames
+
         let dtype = visionModel.patchEmbed.proj.weight.dtype

         let imageGridThw = input.image?.imageGridThw
@@ -891,7 +952,7 @@
         }

         let inputEmbeddings = self.inputEmbeddings(
-            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
+            inputIds: input.text.tokens, pixelValues: pixels, frames: frames)

         let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)

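Note: the heart of this file's change is in prepare(input:). The tokenizer's chat template emits one <|image_pad|> per image; the processor then expands it to thw.product / mergeLength pad tokens so the text token count matches the number of vision embeddings. A self-contained sketch of that splice with made-up token ids (single image; the real code zips one replacement per frame):

    // Sketch of the placeholder splice on plain [Int] token arrays.
    // Toy ids: 901/902/903 stand for <|vision_start|>/<|image_pad|>/<|vision_end|>.
    let placeholder = [901, 902, 903]
    let promptTokens = [1, 901, 902, 903, 2]
    let expanded = [901] + Array(repeating: 902, count: 64) + [903]
    var result: [Int] = []
    var current = promptTokens.startIndex
    for range in promptTokens.ranges(of: placeholder) {   // Collection.ranges(of:), Swift 5.7+
        result.append(contentsOf: promptTokens[current ..< range.lowerBound])
        result.append(contentsOf: expanded)
        current = range.upperBound
    }
    result.append(contentsOf: promptTokens[current...])   // tail after the last match
    // result == [1] + expanded + [2], i.e. 64 pad tokens in place of the single one

Also worth noting: the video path is commented out in this commit, prepare(input:) is no longer async, and the returned LMInput carries only the image.
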
Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ public enum VLMError: Error {
     case maskRequired
     case singleImageAllowed
     case imageProcessingFailure(String)
+    case processing(String)
 }

 public struct BaseProcessorConfiguration: Codable, Sendable {
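
Note: the new VLMError.processing(String) case is what Qwen2VLProcessor throws above when the template's placeholder count and the preprocessed frame count disagree. A minimal sketch of handling it at a call site (the surrounding context and try await usage mirror ContentView.swift; the catch pattern is illustrative):

    do {
        let input = try await context.processor.prepare(input: userInput)
        // ... generate from input ...
    } catch VLMError.processing(let message) {
        // placeholder/frame mismatch reported by the processor
        print("processing failed: \(message)")
    }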
