
Commit b2ef5f1

Use chat template for Qwen 2 VL
1 parent d189ae3 commit b2ef5f1

8 files changed: 97 additions & 80 deletions

Applications/VLMEval/ContentView.swift

Lines changed: 10 additions & 1 deletion
@@ -331,7 +331,16 @@ class VLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

         let result = try await modelContainer.perform { context in
-            var userInput = UserInput(prompt: prompt, images: [.ciImage(image)])
+            var userInput = UserInput(
+                messages: [
+                    [
+                        "role": "user",
+                        "content": [
+                            ["type": "text", "text": prompt],
+                            ["type": "image"],
+                        ],
+                    ]
+                ], images: [.ciImage(image)])
             userInput.processing.resize = .init(width: 448, height: 448)

             let input = try await context.processor.prepare(input: userInput)

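The call-site change above amounts to passing a structured chat message (a role plus an array of typed content parts) instead of a bare prompt string, so the processor can run the model's chat template. A minimal sketch of the two forms, using only API that appears in this commit; the CIImage here is a stand-in for a real image:

import CoreImage
import MLXLMCommon

let prompt = "Describe the image in English"
let image = CIImage(color: .red)  // stand-in for a real image

// Old style: plain text prompt (still supported).
var textInput = UserInput(prompt: prompt, images: [.ciImage(image)])

// New style: structured message with typed content parts, as used for Qwen 2 VL.
var chatInput = UserInput(
    messages: [
        [
            "role": "user",
            "content": [
                ["type": "text", "text": prompt],
                ["type": "image"],
            ],
        ]
    ],
    images: [.ciImage(image)])
chatInput.processing.resize = .init(width: 448, height: 448)
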
Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
         // but that is not public so just fall back to text
         let prompt = input.prompt
             .asMessages()
-            .compactMap { $0["content"] }
+            .compactMap { $0["content"] as? String }
             .joined(separator: ". ")
         let promptTokens = tokenizer.encode(text: prompt)
         return LMInput(tokens: MLXArray(promptTokens))

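This one-line change follows from `Message` widening to `[String: Any]` (see the UserInput.swift diff below): a message's `content` may now be a plain `String` or an array of typed parts, and the text-only LLM fallback keeps just the string contents. A small illustrative sketch with made-up messages:

// Illustrative only: shows what the `as? String` cast filters out.
let messages: [[String: Any]] = [
    ["role": "system", "content": "You are a helpful assistant."],
    [
        "role": "user",
        "content": [
            ["type": "text", "text": "What is in this image?"],
            ["type": "image"],
        ],
    ],
]

// Without the cast this no longer compiles ([Any] cannot be joined);
// with it, structured contents are skipped and only strings survive.
let prompt = messages
    .compactMap { $0["content"] as? String }
    .joined(separator: ". ")
// prompt == "You are a helpful assistant."
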
Libraries/MLXLMCommon/LanguageModel.swift

Lines changed: 3 additions & 3 deletions
@@ -69,13 +69,13 @@ public struct LMInput {
     public struct ProcessedImage {

         public let pixels: MLXArray
-        public let imageGridThw: [THW]?
+        public let frames: [THW]?

         public init(
-            pixels: MLXArray, imageGridThw: [THW]? = nil
+            pixels: MLXArray, frames: [THW]? = nil
         ) {
             self.pixels = pixels
-            self.imageGridThw = imageGridThw
+            self.frames = frames
         }
     }

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 5 additions & 4 deletions
@@ -4,18 +4,19 @@ import CoreImage
 import Foundation
 import MLX

+public typealias Message = [String: Any]
+
 /// Container for raw user input.
 ///
 /// A ``UserInputProcessor`` can convert this to ``LMInput``.
 /// See also ``ModelContext``.
 public struct UserInput: Sendable {
-
     /// Representation of a prompt or series of messages (conversation).
     public enum Prompt: Sendable, CustomStringConvertible {
         case text(String)
-        case messages([[String: String]])
+        case messages([Message])

-        public func asMessages() -> [[String: String]] {
+        public func asMessages() -> [Message] {
             switch self {
             case .text(let text):
                 return [["role": "user", "content": text]]
@@ -116,7 +117,7 @@ public struct UserInput: Sendable {
         self.images = images
     }

-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(messages: [Message], images: [Image] = [Image]()) {
         self.prompt = .messages(messages)
         self.images = images
     }

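With the `Message` typealias, both prompt forms normalize to the same `[Message]` shape through `asMessages()`, which is what the processors now feed to the chat template. A short sketch based on the cases shown above; the structured case is assumed to pass through unchanged:

import MLXLMCommon

// Plain text is wrapped into a single user message.
let text: UserInput.Prompt = .text("Describe the image in English")
let wrapped: [Message] = text.asMessages()
// [["role": "user", "content": "Describe the image in English"]]

// Structured messages are returned as-is, including typed content parts.
let chat: UserInput.Prompt = .messages([
    [
        "role": "user",
        "content": [
            ["type": "text", "text": "Describe the image in English"],
            ["type": "image"],
        ],
    ]
])
let passedThrough: [Message] = chat.asMessages()
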
Libraries/MLXVLM/Models/Idefics3.swift

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
     }

     public func prepare(input: UserInput) throws -> LMInput {
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+        let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         if input.images.isEmpty {
             // No image scenario

Libraries/MLXVLM/Models/Paligemma.swift

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
     }

         // this doesn't have a chat template so just use the last message.
-        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+        var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

         // based on transformers/processing_paligemma
         let count = input.images.count * config.imageSequenceLength

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 75 additions & 69 deletions
@@ -367,10 +367,10 @@ private enum Vision {
         }

         public func callAsFunction(
-            _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ x: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             let sequenceLength = x.dim(0)
-            let B = gridThw[0].t
+            let B = frames[0].t
             let L = sequenceLength / B

             let qkv = qkv(x)
@@ -435,13 +435,13 @@ private enum Vision {
         }

         func callAsFunction(
-            _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
+            _ hiddenStates: MLXArray, frames: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
             var hiddenStates =
                 hiddenStates
                 + attention(
                     norm1(hiddenStates),
-                    gridThw: gridThw,
+                    frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding
                 )
             hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
@@ -479,10 +479,10 @@ private enum Vision {
                 spatialMergeSize: 2)
         }

-        func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray {
+        func rotaryPositionEmbedding(_ frames: [THW]) -> MLXArray {
             var positionIds = [MLXArray]()

-            for row in gridThw {
+            for row in frames {
                 let (t, h, w) = row.values

                 var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
@@ -516,22 +516,22 @@ private enum Vision {
             }

             let indices = concatenated(positionIds, axis: 0)
-            let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0
-            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
+            let maxFrameSize = frames.lazy.map { max($0.h, $0.w) }.max() ?? 0
+            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxFrameSize)[
                 indices]

             return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
         }

-        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
+        public func callAsFunction(_ hiddenStates: MLXArray, frames: [THW]) -> MLXArray {
             var hiddenStates = patchEmbed(hiddenStates)
-            let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+            let rotaryPositionEmbedding = rotaryPositionEmbedding(frames)

-            let batchSize = gridThw.count
+            let batchSize = frames.count

             for block in blocks {
                 hiddenStates = block(
-                    hiddenStates, gridThw: gridThw,
+                    hiddenStates, frames: frames,
                     rotaryPositionEmbedding: rotaryPositionEmbedding)
             }

@@ -585,6 +585,10 @@ private enum Vision {
 /// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
 public class Qwen2VLProcessor: UserInputProcessor {

+    enum Qwen2VLProcessorError: Error {
+        case framesIsNil
+    }
+
     private let config: Qwen2VLProcessorConfiguration
     private let tokenizer: any Tokenizer

@@ -686,72 +690,74 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }

-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
+    private func prepareMessages(_ messages: [Message]) -> [Message] {
+        var messages = messages
+        // Add system message if not present
+        if let role = messages[0]["role"] as? String, role != "system" {
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
-
-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
-            }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
+        return messages
     }

+    // public func prepare(prompt: UserInput.Prompt, frames: [THW]?) throws -> String {
+    //     let messages = prepareMessages(prompt.asMessages())
+    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
+    //     return tokenizer.decode(tokens: tokens)
+    // }
+
     public func prepare(input: UserInput) throws -> LMInput {
+        // Text-only input
         if input.images.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
+            let messages = input.prompt.asMessages()
+            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
             return LMInput(tokens: MLXArray(promptTokens))
         }
-
-        // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map {
+        // Input with images
+        let pixelsAndFrames = try input.images.map {
            try preprocess(images: [$0.asCIImage()], processing: input.processing)
        }
-        let pixels = concatenated(images.map { $0.0 })
-        let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
-
-        // processing_qwen2_vl.Qwen2VLProcessor
-        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
+        let pixelsConcatenated = concatenated(pixelsAndFrames.map { $0.0 })
+        let image = LMInput.ProcessedImage(
+            pixels: pixelsConcatenated, frames: pixelsAndFrames.map { $0.1 })
+        let messages = prepareMessages(input.prompt.asMessages())
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+        // Replace single image pad token with correct number for each image
+        let mergeLength = config.mergeSize * config.mergeSize
+        let imagePlaceholderTokens = try tokenizer.encode(
+            text: "<|vision_start|><|image_pad|><|vision_end|>")
+        guard let frames = image.frames else {
+            throw Qwen2VLProcessorError.framesIsNil
+        }
+        let placeholderRanges = promptTokens.ranges(of: imagePlaceholderTokens)
+        guard placeholderRanges.count == frames.count else {
+            throw VLMError.processing("Number of image placeholders does not match number of frames")
+        }
+        let replacementSequences = try frames.map { thw in
+            let paddingCount = thw.product / mergeLength
+            return try tokenizer.encode(
+                text:
+                    "<|vision_start|>\(Array(repeating: "<|image_pad|>", count: paddingCount).joined())<|vision_end|>"
+            )
+        }
+        // Build the final array
+        var result: [Int] = []
+        var currentIndex = promptTokens.startIndex
+        for (range, replacement) in zip(placeholderRanges, replacementSequences) {
+            // Add tokens before the placeholder
+            result.append(contentsOf: promptTokens[currentIndex ..< range.lowerBound])
+            // Add replacement sequence
+            result.append(contentsOf: replacement)
+            currentIndex = range.upperBound
+        }
+        // Add any remaining tokens after the last replacement
+        if currentIndex < promptTokens.endIndex {
+            result.append(contentsOf: promptTokens[currentIndex...])
+        }
+        promptTokens = result
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray).asType(.int8)
-
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
-
 }

 // MARK: - Model
@@ -779,18 +785,18 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
         self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
     }

-    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?)
+    private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?)
         -> MLXArray
     {
-        guard let pixelValues, let gridThw else {
+        guard let pixelValues, let frames else {
             return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis])
         }

         // Get the input embeddings from the language model
         let inputEmbeds = languageModel.model.embedTokens(inputIds)

         // Get the ouptut hidden states from the vision model
-        var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw)
+        var hiddenStates = self.visionModel(pixelValues, frames: frames)

         if hiddenStates.ndim == 2 {
             hiddenStates = hiddenStates[.newAxis, 0..., 0...]
@@ -820,13 +826,13 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
         -> PrepareResult
     {
-        let gridThw = input.image?.imageGridThw
+        let frames = input.image?.frames

         let dtype = visionModel.patchEmbed.proj.weight.dtype
         let pixels = input.image?.pixels.asType(dtype)

         let inputEmbeddings = self.inputEmbeddings(
-            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
+            inputIds: input.text.tokens, pixelValues: pixels, frames: frames)

         let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)

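The core of the new `prepare(input:)` above is the splice that swaps each `<|vision_start|><|image_pad|><|vision_end|>` placeholder emitted by the chat template for the right number of pad tokens (`thw.product / mergeLength` per image, where `mergeLength = mergeSize * mergeSize`). A standalone sketch of that splice on plain `[Int]` arrays; the helper name, the token ids, and the 16 x 16 grid are made up for illustration, while `ranges(of:)` is the same standard-library collection search the diff uses:

// Splices expanded replacements over each occurrence of a placeholder token run.
// Mirrors the loop in Qwen2VLProcessor.prepare(input:); not the library API itself.
func expandPlaceholders(in tokens: [Int], placeholder: [Int], replacements: [[Int]]) -> [Int] {
    let ranges = tokens.ranges(of: placeholder)  // Swift 5.7+ collection search
    precondition(ranges.count == replacements.count, "one replacement per placeholder")
    var result: [Int] = []
    var current = tokens.startIndex
    for (range, replacement) in zip(ranges, replacements) {
        result.append(contentsOf: tokens[current ..< range.lowerBound])  // tokens before the placeholder
        result.append(contentsOf: replacement)                           // expanded pad run
        current = range.upperBound
    }
    result.append(contentsOf: tokens[current...])                        // tail after the last placeholder
    return result
}

// One image with a (t, h, w) grid of (1, 16, 16) and mergeSize 2:
// paddingCount = 1 * 16 * 16 / (2 * 2) = 64 image-pad tokens.
let visionStart = 90, imagePad = 91, visionEnd = 92  // hypothetical token ids
let templated = [1, 2, visionStart, imagePad, visionEnd, 3]
let expanded = expandPlaceholders(
    in: templated,
    placeholder: [visionStart, imagePad, visionEnd],
    replacements: [[visionStart] + Array(repeating: imagePad, count: 64) + [visionEnd]])
// expanded.count == 2 + 66 + 1 == 69
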
Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ public enum VLMError: Error {
     case maskRequired
     case singleImageAllowed
     case imageProcessingFailure(String)
+    case processing(String)
 }

 public struct BaseProcessorConfiguration: Codable, Sendable {
