Use chat templates for vision models #173

Merged · 6 commits · Feb 12, 2025

30 changes: 27 additions & 3 deletions Applications/VLMEval/ContentView.swift
@@ -383,9 +383,33 @@ class VLMEvaluator {
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

let result = try await modelContainer.perform { context in
-    let images: [UserInput.Image] = image != nil ? [.ciImage(image!)] : []
-    let videos: [UserInput.Video] = videoURL != nil ? [.url(videoURL!)] : []
-    var userInput = UserInput(prompt: prompt, images: images, videos: videos)
+    let images: [UserInput.Image] =
+        if let image {
+            [UserInput.Image.ciImage(image)]
+        } else {
+            []
+        }
+    let videos: [UserInput.Video] =
+        if let videoURL {
+            [.url(videoURL)]
+        } else {
+            []
+        }
+    var userInput = UserInput(
+        messages: [
+            [
+                "role": "user",
+                "content": [
+                    ["type": "text", "text": prompt]
+                ]
+                    + images.map { _ in
+                        ["type": "image"]
+                    }
+                    + videos.map { _ in
+                        ["type": "video"]
+                    },
Collaborator

I wonder if this should happen inside Qwen2VLProcessor? As it stands the llm-tool doesn't work because it doesn't have this augmentation.

In the python code this is specific to the model, but handled outside the model/processing code. I think it belongs with the UserInputProcessor as that is where all of these would come together.
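
For reference, here is a minimal sketch of what that could look like if the augmentation lived in a model-specific UserInputProcessor rather than in each app. The helper name and the placement logic are assumptions for illustration, not the current Qwen2VLProcessor API:

    import MLXLMCommon

    // Hypothetical helper: append one placeholder entry per attached image/video
    // to the last user message, mirroring what the app code above constructs.
    func withMediaTokens(_ messages: [Message], for input: UserInput) -> [Message] {
        var messages = messages
        guard var last = messages.last, last["role"] as? String == "user" else {
            return messages
        }

        // Accept either an already-structured content array or a plain string prompt.
        var content: [[String: Any]]
        if let structured = last["content"] as? [[String: Any]] {
            content = structured
        } else if let text = last["content"] as? String {
            content = [["type": "text", "text": text]]
        } else {
            content = []
        }

        // Qwen-style chat templates expect a "type": "image" / "video" entry
        // for every attached image or video.
        content += input.images.map { _ in ["type": "image"] }
        content += input.videos.map { _ in ["type": "video"] }

        last["content"] = content
        messages[messages.count - 1] = last
        return messages
    }

A processor's prepare(input:) could run something like this before applying the chat template, so llm-tool and the example app would not each need their own copy.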

Contributor Author

Leaving the construction of the messages to the app would make this more flexible. I can imagine that in the future there might be different message formats. I'm not certain about this, but this approach is working well for me in my app at the moment.

Collaborator

As it is, it means the llm-tool doesn't work for Qwen models -- it gets an error because it is missing the tokens.

On the python side this looks like:

    # Model to format mapping
    model_to_format = {
        # Models using message_list_with_image format
        "idefics2": "message_list_with_image",
        "idefics3": "message_list_with_image",
        "llava": "message_list_with_image",
        "llava_next": "message_list_with_image",
        "mllama": "message_list_with_image",
        # Models that can handle both image and video formats
        "qwen2_vl": (
            "message_video_with_text"
            if kwargs.get("video")
            else "message_list_with_image"
        ),
        "qwen2_5_vl": (
            "message_video_with_text"
            if kwargs.get("video")
            else "message_list_with_image"
        ),

and this matches the swift code:

    # Message format handlers
    def handle_list_with_image():
        content = [create_text_message(prompt)]
        if role == "user" and not skip_image_token:
            image_tokens = [{"type": "image"}] * num_images
            content = (
                image_tokens + content
                if model_name in ["pixtral", "idefics3"]
                else content + image_tokens
            )
        return {"role": role, "content": content}

So the problems of leaving it to the app are two-fold:

  • each application/tool has to have a copy of this code
  • the code varies per model type and the app needs to have a table mapping to the right message structure

Perhaps we could have a way to mark the messages as already being processed (or the UserInputProcessor could inspect the messages and detect that), leaving it up to the app, but I am not sure what the app would do differently from the generic processing required by the model type.
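
A rough sketch of the detection idea (purely illustrative, nothing like this exists in the processors yet): a processor could treat a message as already processed when its content is the structured array form rather than a plain string.

    import MLXLMCommon

    // Illustrative check only: content that is already an array of typed entries
    // (["type": "text"], ["type": "image"], ...) is assumed to have been built by
    // the app, while a plain String content still needs the generic augmentation.
    func messagesArePreformatted(_ messages: [Message]) -> Bool {
        messages.contains { ($0["content"] as? [[String: Any]]) != nil }
    }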

Collaborator

> I can imagine that in the future there might be different message formats.

Yes, for sure there are, but per the python code they vary by model type and somewhat by the presence of video vs images.
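
A Swift analogue of the python model_to_format table could look roughly like this (a sketch only; neither the function nor these format names exist in this repo):

    // Sketch of a per-model message-format table mirroring the python mapping above.
    func messageFormat(for modelType: String, hasVideo: Bool) -> String {
        switch modelType {
        case "idefics2", "idefics3", "llava", "llava_next", "mllama":
            return "message_list_with_image"
        case "qwen2_vl", "qwen2_5_vl":
            return hasVideo ? "message_video_with_text" : "message_list_with_image"
        default:
            // Other model types elided; assumed default for illustration.
            return "message_list_with_image"
        }
    }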

Contributor Author

I think the issue might be multi-turn conversations. I don't think UserInput has a way to designate to which turn an image or video belongs, so perhaps this is best left to the app?
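
To illustrate the multi-turn point with the message format from this PR (the conversation content is invented for illustration): the app can place an image placeholder on the exact turn it belongs to, which the flat images array on UserInput cannot express by itself.

    import MLXLMCommon

    // Hypothetical multi-turn conversation: each "image" entry is tied to a turn.
    let conversation: [Message] = [
        [
            "role": "user",
            "content": [
                ["type": "text", "text": "What is in this photo?"],
                ["type": "image"]
            ]
        ],
        [
            "role": "assistant",
            "content": [["type": "text", "text": "A cat sitting on a windowsill."]]
        ],
        [
            "role": "user",
            "content": [
                ["type": "text", "text": "And in this one?"],
                ["type": "image"]
            ]
        ]
    ]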

Collaborator

Yes, I think that works ok -- we can get experience with it and take our time on the longer term approach. I suggest:

  • add this code to llm-tool so that it can still work with Qwen
  • add a comment to the code indicating that it may be model dependent (in case somebody adds a model where this doesn't work)

Then we can consider the approach at our leisure. Sound good?

Contributor Author

As it is, apps already have to handle how the system message is treated in the construction of messages, since some models support a system role and others don't.
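
For example, a sketch of the branching an app already does today (supportsSystemRole is an assumed per-model flag, not an existing API):

    import MLXLMCommon

    // Sketch: emit a system message when the model's chat template supports one,
    // otherwise fold the instructions into the first user turn.
    func makeMessages(prompt: String, instructions: String, supportsSystemRole: Bool) -> [Message] {
        var messages: [Message] = []
        var prompt = prompt
        if supportsSystemRole {
            messages.append(["role": "system", "content": instructions])
        } else {
            prompt = instructions + "\n\n" + prompt
        }
        messages.append(["role": "user", "content": [["type": "text", "text": prompt]]])
        return messages
    }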

Collaborator

I think there is a lot of value in getting the chat template code merged sooner rather than perfecting everything around it :-)

Contributor Author

Sounds great! I'll make the changes to llm-tool.

Contributor Author

llm-tool now works.

+            ]
+        ], images: images, videos: videos)
userInput.processing.resize = .init(width: 448, height: 448)

let input = try await context.processor.prepare(input: userInput)
2 changes: 1 addition & 1 deletion Libraries/MLXLLM/LLMModelFactory.swift
@@ -249,7 +249,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
// but that is not public so just fall back to text
let prompt = input.prompt
.asMessages()
.compactMap { $0["content"] }
.compactMap { $0["content"] as? String }
.joined(separator: ". ")
let promptTokens = tokenizer.encode(text: prompt)
return LMInput(tokens: MLXArray(promptTokens))
14 changes: 8 additions & 6 deletions Libraries/MLXLMCommon/LanguageModel.swift
@@ -69,14 +69,16 @@ public struct LMInput {
/// Representation of prepared input image(s).
public struct ProcessedImage {

+/// Concatenated pixels from one or more images
public let pixels: MLXArray
-public let imageGridThw: [THW]?
+/// Time, height, and width of the images
+public let frames: [THW]?

public init(
-pixels: MLXArray, imageGridThw: [THW]? = nil
+pixels: MLXArray, frames: [THW]? = nil
) {
self.pixels = pixels
-self.imageGridThw = imageGridThw
+self.frames = frames
}
}

@@ -85,13 +87,13 @@ public struct LMInput {
public struct ProcessedVideo {

public let pixels: MLXArray
-public let videoGridThw: [THW]?
+public let frames: [THW]?

public init(
-pixels: MLXArray, videoGridThw: [THW]? = nil
+pixels: MLXArray, frames: [THW]? = nil
) {
self.pixels = pixels
-self.videoGridThw = videoGridThw
+self.frames = frames
}
}

11 changes: 7 additions & 4 deletions Libraries/MLXLMCommon/UserInput.swift
@@ -6,18 +6,19 @@ import Foundation
import MLX
import Tokenizers

+public typealias Message = [String: Any]

/// Container for raw user input.
///
/// A ``UserInputProcessor`` can convert this to ``LMInput``.
/// See also ``ModelContext``.
public struct UserInput: Sendable {

/// Representation of a prompt or series of messages (conversation).
public enum Prompt: Sendable, CustomStringConvertible {
case text(String)
-case messages([[String: String]])
+case messages([Message])

-public func asMessages() -> [[String: String]] {
+public func asMessages() -> [Message] {
switch self {
case .text(let text):
return [["role": "user", "content": text]]
@@ -144,11 +145,13 @@ public struct UserInput {
}

public init(
-messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+messages: [Message], images: [Image] = [Image](), videos: [Video] = [Video](),
+tools: [ToolSpec]? = nil,
additionalContext: [String: Any]? = nil
) {
self.prompt = .messages(messages)
self.images = images
+self.videos = videos
self.tools = tools
self.additionalContext = additionalContext
}
2 changes: 1 addition & 1 deletion Libraries/MLXVLM/Models/Idefics3.swift
@@ -805,7 +805,7 @@ public class Idefics3Processor: UserInputProcessor {
}

public func prepare(input: UserInput) throws -> LMInput {
-let prompt = input.prompt.asMessages().last?["content"] ?? ""
+let prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

if input.images.isEmpty {
// No image scenario
2 changes: 1 addition & 1 deletion Libraries/MLXVLM/Models/Paligemma.swift
@@ -478,7 +478,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
}

// this doesn't have a chat template so just use the last message.
-var prompt = input.prompt.asMessages().last?["content"] ?? ""
+var prompt = input.prompt.asMessages().last?["content"] as? String ?? ""

// based on transformers/processing_paligemma
let count = input.images.count * config.imageSequenceLength
Expand Down