
mtmd : Support jinja in libmtmd (Only for QwenVL and Qwen Omni) #14730

Open (wants to merge 3 commits into base: master)
3 changes: 2 additions & 1 deletion common/chat.cpp
@@ -1727,7 +1727,8 @@ static common_chat_params common_chat_templates_apply_jinja(
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
bool concat_text = !inputs.no_part_concat && !caps.requires_typed_content;
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, concat_text);
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.enable_thinking = inputs.enable_thinking;
3 changes: 3 additions & 0 deletions common/chat.h
@@ -127,6 +127,9 @@ struct common_chat_templates_inputs {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;

// If true, jinja will not concatenate content parts into a single part. This is useful for media parts.
bool no_part_concat = false;
};

struct common_chat_params {
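For illustration, a minimal usage sketch of the new flag (hypothetical caller code; tmpls is assumed to be an already-loaded common_chat_templates handle):

// Sketch: keep multimodal parts separate so the jinja template can emit
// per-part media tokens (e.g. <|IMAGE|>) instead of one concatenated string.
common_chat_templates_inputs inputs;
common_chat_msg msg;
msg.role = "user";
msg.content_parts.push_back({"text",  "What is in this image?"});
msg.content_parts.push_back({"image", ""}); // media part, left un-concatenated
inputs.messages              = {msg};
inputs.use_jinja             = true;
inputs.no_part_concat        = true; // the new flag
inputs.add_generation_prompt = true;
auto formatted = common_chat_templates_apply(tmpls.get(), inputs);
// formatted.prompt now carries model-specific media tokens at the image position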
3 changes: 3 additions & 0 deletions include/llama.h
@@ -1021,6 +1021,9 @@ extern "C" {
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);

LLAMA_API llama_token llama_vocab_image_token(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_audio_token(const struct llama_vocab * vocab);

DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
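A short sketch of how a client might consume the two new entry points (assumes a model already loaded via llama_model_load_from_file):

// Sketch: both accessors return LLAMA_TOKEN_NULL when the GGUF file
// declares no dedicated media token.
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_token img_tok = llama_vocab_image_token(vocab);
llama_token aud_tok = llama_vocab_audio_token(vocab);
if (img_tok != LLAMA_TOKEN_NULL) {
    printf("image token: '%s' (id %d)\n", llama_vocab_get_text(vocab, img_tok), img_tok);
}
if (aud_tok == LLAMA_TOKEN_NULL) {
    printf("no dedicated audio token in this model\n");
}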
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
@@ -224,6 +224,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
{ LLM_KV_TOKENIZER_IMAGE_ID, "tokenizer.ggml.image_token_id" },
{ LLM_KV_TOKENIZER_AUDIO_ID, "tokenizer.ggml.audio_token_id" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
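For reference, a hedged sketch of the writer side that would populate these keys during conversion (the real converters live in gguf-py; this hypothetical C++ helper uses ggml's gguf API to the same effect):

#include "gguf.h"

// Hypothetical helper: emit the new token-id keys so that
// llama_vocab::impl::load() (src/llama-vocab.cpp below) can pick them up.
static void write_media_token_ids(gguf_context * gctx, uint32_t image_id, uint32_t audio_id) {
    gguf_set_val_u32(gctx, "tokenizer.ggml.image_token_id", image_id);
    gguf_set_val_u32(gctx, "tokenizer.ggml.audio_token_id", audio_id);
}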
5 changes: 5 additions & 0 deletions src/llama-arch.h
@@ -230,8 +230,13 @@ enum llm_kv {

LLM_KV_CLASSIFIER_OUTPUT_LABELS,

LLM_KV_TOKENIZER_IMAGE_ID,
LLM_KV_TOKENIZER_AUDIO_ID,

LLM_KV_SHORTCONV_L_CACHE,
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
40 changes: 40 additions & 0 deletions src/llama-vocab.cpp
@@ -1551,6 +1551,8 @@ struct llama_vocab::impl {
llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
llama_token special_image_id = LLAMA_TOKEN_NULL;
llama_token special_audio_id = LLAMA_TOKEN_NULL;

// tokenizer flags
bool add_space_prefix = false;
@@ -1999,6 +2001,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
}

const int image_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_IMAGE_ID).c_str());
if (image_idx != -1) {
special_image_id = gguf_get_val_u32(ctx, image_idx);
}
const int audio_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_AUDIO_ID).c_str());
if (audio_idx != -1) {
special_audio_id = gguf_get_val_u32(ctx, audio_idx);
}

const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
if (token_idx == -1) {
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
@@ -2034,6 +2044,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
token_data.score = scores ? scores[i] : 0.0f;
token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
switch(toktypes[i]) {
case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break;
@@ -2094,6 +2105,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
{ LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
{ LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
{ LLM_KV_TOKENIZER_IMAGE_ID, special_image_id },
{ LLM_KV_TOKENIZER_AUDIO_ID, special_audio_id },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
@@ -2172,6 +2185,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
}
}
}
// find IMAGE token: "<|IMAGE|>", "<IMAGE>"
if (special_image_id == LLAMA_TOKEN_NULL) {
if (t.first == "<|IMAGE|>" || t.first == "<IMAGE>") {
special_image_id = t.second;
}
}
// find AUDIO token: "<|AUDIO|>", "<AUDIO>"
if (special_audio_id == LLAMA_TOKEN_NULL) {
if (t.first == "<|AUDIO|>" || t.first == "<AUDIO>") {
special_audio_id = t.second;
}
}

// find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
@@ -3354,6 +3377,15 @@ llama_token llama_vocab::token_fim_sep() const {
return pimpl->special_fim_sep_id;
}

llama_token llama_vocab::token_image() const {
return pimpl->special_image_id;
}

llama_token llama_vocab::token_audio() const {
return pimpl->special_audio_id;
}

llama_token llama_vocab::token_mask() const {
return pimpl->special_mask_id;
}
@@ -3598,6 +3630,14 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
return vocab->token_fim_sep();
}

llama_token llama_vocab_image_token(const struct llama_vocab * vocab) {
return vocab->token_image();
}

llama_token llama_vocab_audio_token(const struct llama_vocab * vocab) {
return vocab->token_audio();
}

llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
return vocab->token_mask();
}
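Bridging the vocab side to libmtmd, a sketch of deriving the literal marker string a template will emit (given a const llama_vocab * vocab; mtmd_default_marker() is the existing generic fallback, and the Qwen token text is only an example):

// Sketch: choose the split marker the tokenizer should look for.
std::string marker;
llama_token img_tok = llama_vocab_image_token(vocab);
if (img_tok != LLAMA_TOKEN_NULL) {
    marker = llama_vocab_get_text(vocab, img_tok); // e.g. "<|IMAGE|>" on Qwen Omni
} else {
    marker = mtmd_default_marker();                // generic "<__media__>"
}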
2 changes: 2 additions & 0 deletions src/llama-vocab.h
@@ -114,6 +114,8 @@ struct llama_vocab {
llama_token token_fim_rep() const;
llama_token token_fim_sep() const;

llama_token token_image() const;
llama_token token_audio() const;
bool get_add_space_prefix () const;
bool get_add_bos () const;
bool get_add_eos () const;
2 changes: 1 addition & 1 deletion tools/mtmd/clip.cpp
@@ -164,7 +164,7 @@ enum patch_merge_type {

struct clip_hparams {
int32_t image_size;
int32_t patch_size;
int32_t patch_size = INT_MAX; // sentinel default so a never-loaded value is detectable
int32_t n_embd;
int32_t n_ff;
int32_t projection_dim;
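A hedged sketch of what the INT_MAX default enables (assuming the loader leaves patch_size untouched when the corresponding GGUF key is absent):

// Sketch: INT_MAX means "never loaded", so it can be validated up front
// instead of silently producing a bogus patch count later.
if (hparams.patch_size == INT_MAX || hparams.patch_size <= 0) {
    throw std::runtime_error("clip: patch_size missing or invalid in model file");
}
const int n_per_side = hparams.image_size / hparams.patch_size;
const int n_patches  = n_per_side * n_per_side;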
44 changes: 27 additions & 17 deletions tools/mtmd/mtmd-cli.cpp
@@ -160,8 +160,9 @@ struct mtmd_cli_context {
}
};

static int generate_response(mtmd_cli_context & ctx, int n_predict) {
static std::string generate_response(mtmd_cli_context & ctx, int n_predict) {
llama_tokens generated_tokens;
std::string response;
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating || g_is_interrupted) {
LOG("\n");
@@ -176,8 +177,9 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
LOG("\n");
break; // end of generation
}

LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
std::string piece = common_token_to_piece(ctx.lctx, token_id);
LOG("%s", piece.c_str());
response += piece;
fflush(stdout);

if (g_is_interrupted) {
@@ -190,17 +192,18 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
if (llama_decode(ctx.lctx, ctx.batch)) {
LOG_ERR("failed to decode token\n");
return 1;
return ""; // an empty string doubles as the error signal to callers
}
}
return 0;
return response;
}

static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
static int eval_message(mtmd_cli_context & ctx, const std::vector<common_chat_msg> & messages, bool add_bos = false) {
common_chat_templates_inputs tmpl_inputs;
tmpl_inputs.messages = {msg};
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = true;
tmpl_inputs.use_jinja = false; // jinja is buggy here
tmpl_inputs.no_part_concat = true;
tmpl_inputs.use_jinja = true; // jinja can now render the media parts (see common/chat.cpp)
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());

@@ -303,10 +306,10 @@ int main(int argc, char ** argv) {
return 1; // error is already printed by libmtmd
}
}
if (eval_message(ctx, msg, true)) {
if (eval_message(ctx, {msg}, true)) {
return 1;
}
if (!g_is_interrupted && generate_response(ctx, n_predict)) {
if (!g_is_interrupted && generate_response(ctx, n_predict).empty()) {
return 1;
}

@@ -324,7 +327,7 @@ int main(int argc, char ** argv) {

bool is_first_msg = true;
std::string content;

std::vector<common_chat_msg> messages;
while (!g_is_interrupted) {
g_is_generating = false;
LOG("\n> ");
@@ -357,24 +360,31 @@ int main(int argc, char ** argv) {
std::string media_path = line.substr(7);
if (ctx.load_media(media_path)) {
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
content += mtmd_default_marker();
common_chat_msg msg;
msg.role = "user"; // set the role so the template renders this as a user turn
msg.content_parts.push_back({"image", ""});
messages.push_back(std::move(msg));
}
// else, error is already printed by libmtmd
continue;
} else {
content += line;
}
common_chat_msg msg;
msg.role = "user";
msg.content = content;
int ret = eval_message(ctx, msg, is_first_msg);
msg.content = line;
messages.push_back(std::move(msg));
int ret = eval_message(ctx, messages, is_first_msg);
if (ret) {
return 1;
}
if (g_is_interrupted) break;
if (generate_response(ctx, n_predict)) {
auto response = generate_response(ctx, n_predict);
if (response.empty()) {
return 1;
}
common_chat_msg response_message;
response_message.role = "assistant";
response_message.content = response;
messages.push_back(response_message);
content.clear();
is_first_msg = false;
}
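Condensed, the new interactive flow keeps the whole conversation and re-renders it every turn (a sketch of the code above, not an exact excerpt):

std::vector<common_chat_msg> messages; // full history, re-templated each turn
common_chat_msg user;
user.role = "user";
user.content = "describe the image";
messages.push_back(user);
if (eval_message(ctx, messages, /*add_bos=*/true)) return 1;
std::string reply = generate_response(ctx, n_predict);
if (reply.empty()) return 1; // empty doubles as the error signal
common_chat_msg assistant;
assistant.role = "assistant";
assistant.content = reply;
messages.push_back(assistant); // so the next turn sees the model's answer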
66 changes: 62 additions & 4 deletions tools/mtmd/mtmd.cpp
@@ -254,8 +254,10 @@ struct mtmd_context {

} else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
img_beg = "<|vision_start|>";
img_end = "<|vision_end|>";
// Qwen2-VL and Qwen Omni have their own in-vocab image tokens, so the
// generic markers are intentionally left unset here; the jinja template
// itself produces something like <|vision_bos|><|IMAGE|><|vision_eos|>.
// img_beg = "<|vision_start|>";
// img_end = "<|vision_end|>";

} else if (proj == PROJECTOR_TYPE_LLAMA4) {
// (more details in mtmd_context constructor)
@@ -385,10 +387,21 @@ struct mtmd_tokenizer {

int32_t tokenize(mtmd_input_chunks * output) {
cur.entries.clear();
std::vector<std::string> parts = split_text(input_text, ctx->media_marker);

size_t i_bm = 0; // index of the current bitmap
llama_token image_tok_id = llama_vocab_image_token(vocab);
std::string image_tok_text;
std::vector<std::string> delimiters;
delimiters.push_back(ctx->media_marker);

if (image_tok_id != LLAMA_TOKEN_NULL) {
image_tok_text = llama_vocab_get_text(vocab, image_tok_id);
delimiters.push_back(image_tok_text);
}

std::vector<std::string> parts = split_text_multi(input_text, delimiters);
for (auto & part : parts) {
if (part == ctx->media_marker) {
if (part == ctx->media_marker || part == image_tok_text) {
// this is a marker, we should add the next bitmap
if (i_bm >= bitmaps.size()) {
LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
@@ -707,6 +720,51 @@ struct mtmd_tokenizer {
return result;
}

static std::vector<std::string> split_text_multi(const std::string & input,
                                                 const std::vector<std::string> & delimiters) {
std::vector<std::string> result;
if (input.empty()) {
return result;
}

size_t pos = 0;
while (pos < input.length()) {
// Find the earliest occurring delimiter
size_t best_match_pos = std::string::npos;
std::string best_delimiter;

for (const auto & delimiter : delimiters) {
size_t match_pos = input.find(delimiter, pos);
if (match_pos != std::string::npos &&
(best_match_pos == std::string::npos || match_pos < best_match_pos)) {
best_match_pos = match_pos;
best_delimiter = delimiter;
}
}

if (best_match_pos == std::string::npos) {
// No more delimiters found, add remaining text
if (pos < input.length()) {
result.push_back(input.substr(pos));
}
break;
}

// Add text before delimiter (if any)
if (best_match_pos > pos) {
result.push_back(input.substr(pos, best_match_pos - pos));
}

// Add the delimiter itself
result.push_back(best_delimiter);

// Move past the delimiter
pos = best_match_pos + best_delimiter.length();
}

return result;
}

// copied from common_tokenize
static std::vector<llama_token> mtmd_tokenize_text_internal(
const struct llama_vocab * vocab,
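To illustrate split_text_multi's contract (delimiters are returned as their own parts, in order of first occurrence):

// Sketch: the tokenizer later walks these parts and swaps each delimiter
// part for the corresponding bitmap's embedding chunks.
std::vector<std::string> parts = split_text_multi(
    "describe <__media__> and also <|IMAGE|> please",
    {"<__media__>", "<|IMAGE|>"});
// parts == {"describe ", "<__media__>", " and also ", "<|IMAGE|>", " please"}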