
Commit 033d38e

refactor UserInputProcessor as they're the same for Qwen2VL and Qwen2.5VL
1 parent e3cebcc commit 033d38e

File tree

3 files changed: +234 −437 lines changed

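Only the Qwen25VL.swift side of the change is shown below. The shared QwenVLProcessor class and the QwenVLProcessorConfiguration protocol it is generic over presumably live in the other changed files; the following is a minimal sketch of that pattern, inferred from the declarations that do appear in this diff — the exact protocol requirements and class shape are assumptions, not the code from those files.

import CoreGraphics

// Sketch of the shared configuration protocol: the stored properties both
// processor configurations already declare. The real protocol presumably
// also refines Codable & Sendable, since the concrete struct in the diff
// drops those explicit conformances.
public protocol QwenVLProcessorConfiguration: Codable, Sendable {
    var imageMean: [CGFloat] { get }
    var imageStd: [CGFloat] { get }
    var minPixels: Int { get }
    var maxPixels: Int { get }
    var mergeSize: Int { get }
    var patchSize: Int { get }
    var temporalPatchSize: Int { get }
}

// Sketch of the shared processor: one generic class owns the smart_resize,
// preprocess, and prompt-preparation logic that was previously duplicated.
// The real class presumably also stores the tokenizer and conforms to
// UserInputProcessor, as the deleted Qwen25VLProcessor did.
public class QwenVLProcessor<Config: QwenVLProcessorConfiguration> {
    let config: Config

    public init(_ config: Config) {
        self.config = config
    }
}

With something like this in place, each model reduces to a typealias over the generic processor plus its own configuration struct, which is exactly what the Qwen25VL.swift diff below does.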

Libraries/MLXVLM/Models/Qwen25VL.swift

Lines changed: 8 additions & 217 deletions
@@ -974,218 +974,16 @@ public struct Qwen25VLConfiguration: Codable, Sendable {
     }
 }
 
-public class Qwen25VLProcessor: UserInputProcessor {
-    private let config: Qwen25VLProcessorConfiguration
-    private let tokenizer: any Tokenizer
+// MARK: - Processor
 
-    public init(_ config: Qwen25VLProcessorConfiguration, tokenizer: any Tokenizer) {
-        self.config = config
-        self.tokenizer = tokenizer
-    }
-
-    // image_processing_qwen2_vl.smart_resize
-    private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int)
-        throws -> (Int, Int)
-    {
-        if height < factor {
-            throw VLMError.imageProcessingFailure(
-                "height: \(height) must be larger than factor: \(factor)")
-        }
-        if width < factor {
-            throw VLMError.imageProcessingFailure(
-                "width: \(width) must be larger than factor: \(factor)")
-        }
-        if max(height, width) / min(height, width) > 200 {
-            throw VLMError.imageProcessingFailure(
-                "absolute aspect ratio must be smaller than 200: \(width)x\(height)")
-        }
-
-        var hBar = Int(round(Float(height) / Float(factor))) * factor
-        var wBar = Int(round(Float(width) / Float(factor))) * factor
-
-        if hBar * wBar > maxPixels {
-            let beta = sqrt(Float(height * width) / Float(maxPixels))
-            hBar = Int(floor(Float(height) / beta / Float(factor))) * factor
-            wBar = Int(floor(Float(width) / beta / Float(factor))) * factor
-        } else if hBar * wBar < minPixels {
-            let beta = sqrt(Float(minPixels) / Float(height * width))
-            hBar = Int(floor(Float(height) * beta / Float(factor))) * factor
-            wBar = Int(floor(Float(width) * beta / Float(factor))) * factor
-        }
-        return (hBar, wBar)
-    }
-
-    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
-        MLXArray, THW
-    ) {
-        // first apply the user requested resizing, etc. if any
-        let images = images.map { MediaProcessing.apply($0, processing: processing) }
-
-        // image_processing_qwen2_vl._preprocess
-
-        let size = images[0].extent.size
-        let (resizedHeight, resizedWidth) = try targetSize(
-            height: Int(size.height), width: Int(size.width),
-            factor: config.patchSize * config.mergeSize,
-            minPixels: config.minPixels, maxPixels: config.maxPixels)
-        let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
-
-        let processedImages =
-            try images
-            .map {
-                MediaProcessing.inSRGBToneCurveSpace($0)
-            }
-            .map {
-                return MediaProcessing.resampleBicubic($0, to: resizedSize)
-            }
-            .map {
-                MediaProcessing.normalize(
-                    $0, mean: config.imageMeanTuple, std: config.imageStdTuple)
-            }
-            .map {
-                MediaProcessing.asMLXArray($0)
-            }
-
-        var patches = concatenated(processedImages)
-        let mod = patches.dim(0) % config.temporalPatchSize
-        if mod != 0 {
-            let lastPatch = patches[-1, .ellipsis]
-            let lastPatchRepeated = tiled(
-                lastPatch, repetitions: [config.temporalPatchSize - mod, 1, 1, 1])
-            patches = concatenated([patches, lastPatchRepeated])
-        }
-        let channel = patches.dim(1)
-        let gridT = patches.dim(0) / self.config.temporalPatchSize
-        let gridH = resizedHeight / self.config.patchSize
-        let gridW = resizedWidth / self.config.patchSize
-
-        patches = patches.reshaped(
-            gridT,
-            config.temporalPatchSize,
-            channel,
-            gridH / config.mergeSize,
-            config.mergeSize,
-            config.patchSize,
-            gridW / config.mergeSize,
-            config.mergeSize,
-            config.patchSize
-        )
-        patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8)
-
-        let flattenedPatches = patches.reshaped(
-            gridT * gridH * gridW,
-            channel * config.temporalPatchSize * config.patchSize * config.patchSize
-        )
-
-        return (flattenedPatches, .init(gridT, gridH, gridW))
-    }
-
-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?, videoTHW: [THW]?) -> String {
-        // the tokenizer does have a chat template and it expects messages
-        // like this:
-        //
-        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
-        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
-        //
-        // The output of the prompt template is fed into
-        // image_processing_qwen2_vl.preprocess where it is further augmented
-        // by replacing tokens according to imageTHW.
-        //
-        // Neither the structured content nor the postprocessing of the template
-        // are supported in current Tokenizer/Jinja (swift) so handle that here.
-
-        var messages = prompt.asMessages()
-        if messages[0]["role"] != "system" {
-            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
-        }
-
-        let lastIndex = messages.count - 1
-        var lastMessage = messages[lastIndex]["content"] ?? ""
-
-        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
-        let mergeLength = config.mergeSize * config.mergeSize
-        for thw in imageTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        for thw in videoTHW ?? [] {
-            lastMessage += "<|vision_start|>"
-            lastMessage += Array(repeating: "<|video_pad|>", count: thw.product / mergeLength)
-                .joined()
-            lastMessage += "<|vision_end|>"
-        }
-
-        messages[lastIndex]["content"] = lastMessage
-
-        return
-            messages
-            .map {
-                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
-            }
-            .joined(separator: "\n")
-            + "\n<|im_start|>assistant\n"
-    }
-
-    public func prepare(input: UserInput) async throws -> LMInput {
-        if input.images.isEmpty && input.videos.isEmpty {
-            // just a straight text prompt
-            let prompt = prepare(prompt: input.prompt, imageTHW: nil, videoTHW: nil)
-            let promptTokens = try tokenizer.encode(text: prompt)
-            return LMInput(tokens: MLXArray(promptTokens))
-        }
-
-        // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map {
-            try preprocess(images: [$0.asCIImage()], processing: input.processing)
-        }
-
-        var videosAsImageSequences = [[CIImage]]()
-        for video in input.videos {
-            if let imageSequence = try? await MediaProcessing.asCIImageSequence(
-                video.asAVAsset(), samplesPerSecond: 2)
-            {
-                videosAsImageSequences.append(imageSequence)
-            }
-        }
-        let videos = try videosAsImageSequences.map {
-            try preprocess(images: $0, processing: input.processing)
-        }
-
-        let imagePixels: MLXArray?
-        let image: LMInput.ProcessedImage?
-        if !images.isEmpty {
-            imagePixels = concatenated(images.map { $0.0 })
-            image = LMInput.ProcessedImage(pixels: imagePixels!, imageGridThw: images.map { $0.1 })
-        } else {
-            imagePixels = nil
-            image = nil
-        }
-
-        let videoPixels: MLXArray?
-        let video: LMInput.ProcessedVideo?
-        if !videos.isEmpty {
-            videoPixels = concatenated(videos.map { $0.0 })
-            video = LMInput.ProcessedVideo(pixels: videoPixels!, videoGridThw: videos.map { $0.1 })
-        } else {
-            videoPixels = nil
-            video = nil
-        }
-
-        let prompt = prepare(
-            prompt: input.prompt, imageTHW: image?.imageGridThw, videoTHW: video?.videoGridThw)
-        let promptTokens = try tokenizer.encode(text: prompt)
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray).asType(.int8)
-
-        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image, video: video)
-    }
-}
-
-public struct Qwen25VLProcessorConfiguration: Codable, Sendable {
+/// Qwen25VL VLM `UserInputProcessor`.
+///
+/// This is meant to be used with ``Qwen25VL`` and is typically created by ``VLMModelFactory``.
+///
+public typealias Qwen25VLProcessor = QwenVLProcessor<Qwen25VLProcessorConfiguration>
 
+// Configuration for ``Qwen25VLProcessor``
+public struct Qwen25VLProcessorConfiguration: QwenVLProcessorConfiguration {
     public let imageMean: [CGFloat]
     public let imageStd: [CGFloat]
     public let maxPixels: Int
@@ -1194,13 +992,6 @@ public struct Qwen25VLProcessorConfiguration: Codable, Sendable {
     public let patchSize: Int
     public let temporalPatchSize: Int
 
-    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageMean[0], imageMean[1], imageMean[2])
-    }
-    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
-        (imageStd[0], imageStd[1], imageStd[2])
-    }
-
     enum CodingKeys: String, CodingKey {
         case imageMean = "image_mean"
         case imageStd = "image_std"
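The second hunk drops the imageMeanTuple / imageStdTuple computed properties from the concrete configuration. Presumably the shared code (in the files not shown on this page) now provides them once for every conforming type, for example via a protocol extension along these lines — a sketch against the protocol assumed above, not the actual removed-to code:

extension QwenVLProcessorConfiguration {
    // Derived normalization values previously repeated in each
    // model-specific configuration struct.
    public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) {
        (imageMean[0], imageMean[1], imageMean[2])
    }
    public var imageStdTuple: (CGFloat, CGFloat, CGFloat) {
        (imageStd[0], imageStd[1], imageStd[2])
    }
}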

0 commit comments
