[Model] Qwen-2-VL Support #3125

Draft · wants to merge 3 commits into main
14 changes: 14 additions & 0 deletions python/mlc_llm/conversation_template/qwen2.py
@@ -18,3 +18,17 @@
        stop_token_ids=[151643, 151645],
    )
)

ConvTemplateRegistry.register_conv_template(
    Conversation(
        name="qwen2-vl",
        system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n",
        system_message="You are a helpful assistant.",
        roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"},
        seps=["<|im_end|>\n"],
        role_content_sep="\n",
        role_empty_sep="\n",
        stop_str=["<|endoftext|>", "<|im_end|>"],
        stop_token_ids=[151643, 151645],
    )
)
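
Note for reviewers: the new entry mirrors the existing `qwen2` template just above it, with the same ChatML delimiters and stop conditions (151643 and 151645 are the Qwen tokenizer ids for `<|endoftext|>` and `<|im_end|>`). Assembling the fields by hand, a rendered prompt should come out like this sketch (derived from the fields above, not captured from running the code):

```
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Describe this image.<|im_end|>
<|im_start|>assistant
```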
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -267,6 +267,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"chatml",
"chatml_nosystem",
"qwen2",
"qwen2-vl",
"open_hermes_mistral",
"neural_hermes_mistral",
"llama_default",
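
With `qwen2-vl` added to the allowlist above, generating a chat config for a converted checkpoint would presumably follow the usual `gen_config` invocation; the checkpoint path and quantization choice below are placeholder assumptions, not part of this PR:

```
mlc_llm gen_config ./dist/models/Qwen2-VL-2B-Instruct \
    --quantization q4f16_1 \
    --conv-template qwen2-vl \
    -o ./dist/Qwen2-VL-2B-Instruct-MLC
```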
15 changes: 15 additions & 0 deletions python/mlc_llm/model/model.py
@@ -38,6 +38,7 @@
from .qwen import qwen_loader, qwen_model, qwen_quantization
from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
from .qwen2_vl import qwen2_vl_loader, qwen2_vl_model, qwen2_vl_quantization
from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
@@ -299,6 +300,20 @@ class Model:
"ft-quant": qwen2_moe_quantization.ft_quant,
},
),
"qwen2_vl": Model(
name="qwen2_vl",
model=qwen2_vl_model.QWen2VLLMHeadModel,
config=qwen2_vl_model.QWen2VLConfig,
source={
"huggingface-torch": qwen2_vl_loader.huggingface,
"huggingface-safetensor": qwen2_vl_loader.huggingface,
},
quantize={
"no-quant": qwen2_vl_quantization.no_quant,
"group-quant": qwen2_vl_quantization.group_quant,
"ft-quant": qwen2_vl_quantization.ft_quant,
},
),
"deepseek_v2": Model(
name="deepseek_v2",
model=deepseek_v2_model.DeepseekV2ForCausalLM,
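
Once the `"qwen2_vl"` entry is registered in `MODELS`, the convert and compile flows can resolve the architecture by name. A minimal sanity check, sketched only from the fields visible in this diff (not a test shipped with the PR):

```python
from mlc_llm.model.model import MODELS

model = MODELS["qwen2_vl"]
assert model.name == "qwen2_vl"
# Torch and safetensors checkpoints share one HuggingFace loader.
assert model.source["huggingface-torch"] is model.source["huggingface-safetensor"]
print(sorted(model.quantize))  # ['ft-quant', 'group-quant', 'no-quant']
```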
Empty file.
86 changes: 86 additions & 0 deletions python/mlc_llm/model/qwen2_vl/qwen2_vl_config.py
@@ -0,0 +1,86 @@
"""Configuration classes for Qwen2 VL model."""

import dataclasses
from typing import Any, Dict, Optional

from mlc_llm.support.config import ConfigBase
from mlc_llm.model.qwen2.qwen2_model import QWen2Config

@dataclasses.dataclass
class QWen2VLVisionConfig(ConfigBase):
"""Configuration for the vision part of Qwen2 VL."""

hidden_size: int
intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
patch_size: int = 14
merge_size: int = 2
image_size: int = 448 # Default max image size
num_channels: int = 3
layer_norm_eps: float = 1e-6
max_patches: int = 1024 # Maximum number of patches after merging
hidden_act: str = "gelu"
dtype: str = "float32"
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
num_rope_scales: int = 4
rope_theta: float = 10000

@dataclasses.dataclass
class QWen2VLConfig(QWen2Config):
"""Configuration for the complete Qwen2 VL model."""

vision_config: Optional[QWen2VLVisionConfig] = None
image_size: int = 448
min_image_size: int = 224
max_image_size: int = 448
patch_size: int = 14
merge_size: int = 2
temporal_patch_size: int = 2
min_patch_size: int = 14
max_patch_size: int = 28
min_pixels: int = 56*56
max_pixels: int = 28*28*1280
dtype: str = "float32"
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
# First run parent class post init
super().__post_init__()

# Set up vision config if not provided
if self.vision_config is None:
self.vision_config = QWen2VLVisionConfig(
hidden_size=1024, # Vision hidden size
intermediate_size=4096, # Vision MLP size
num_hidden_layers=24, # Number of vision transformer layers
num_attention_heads=16, # Number of vision attention heads
patch_size=self.patch_size,
merge_size=self.merge_size,
image_size=self.image_size,
layer_norm_eps=1e-6,
dtype=self.dtype,
)

# Validate configuration
if self.patch_size < self.min_patch_size or self.patch_size > self.max_patch_size:
raise ValueError(
f"patch_size must be between {self.min_patch_size} and {self.max_patch_size}, "
f"got {self.patch_size}"
)

if self.image_size < self.min_image_size or self.image_size > self.max_image_size:
raise ValueError(
f"image_size must be between {self.min_image_size} and {self.max_image_size}, "
f"got {self.image_size}"
)

# Calculate maximum patches based on image size and patch size
max_h = self.max_image_size // (self.patch_size * self.merge_size)
max_w = self.max_image_size // (self.patch_size * self.merge_size)
self.vision_config.max_patches = max_h * max_w

# Add any additional kwargs
for k, v in self.kwargs.items():
if not hasattr(self, k):
setattr(self, k, v)
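
Working the `max_patches` arithmetic in `__post_init__` through with the defaults in this file: 448 // (14 × 2) = 16 merged patches per side, so `vision_config.max_patches` is tightened from the 1024 placeholder to 256:

```python
# Defaults from QWen2VLConfig above: max_image_size=448, patch_size=14, merge_size=2.
patch_size, merge_size, max_image_size = 14, 2, 448
max_h = max_image_size // (patch_size * merge_size)  # 448 // 28 = 16
max_w = max_image_size // (patch_size * merge_size)  # 16
print(max_h * max_w)  # 256 merged patches overwrite the 1024 default
```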