From bf4d631ab03737db198cb765f8046ac29e3b10c9 Mon Sep 17 00:00:00 2001
From: Nihal John George
Date: Mon, 10 Feb 2025 09:55:08 -0500
Subject: [PATCH 1/3] add model registration and other boilerplate

---
 python/mlc_llm/conversation_template/qwen2.py    | 14 ++++++++++++++
 python/mlc_llm/interface/gen_config.py           |  1 +
 python/mlc_llm/model/model.py                    | 15 +++++++++++++++
 python/mlc_llm/model/qwen2_vl/__init__.py        |  0
 python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py  |  0
 python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py |  0
 python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py  |  0
 .../model/qwen2_vl/qwen2_vl_quantization.py      |  0
 python/mlc_llm/model/vision/image_processing.py  |  2 +-
 9 files changed, 31 insertions(+), 1 deletion(-)
 create mode 100644 python/mlc_llm/model/qwen2_vl/__init__.py
 create mode 100644 python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py
 create mode 100644 python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py
 create mode 100644 python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py
 create mode 100644 python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py

diff --git a/python/mlc_llm/conversation_template/qwen2.py b/python/mlc_llm/conversation_template/qwen2.py
index bd4082b456..890e61ff59 100644
--- a/python/mlc_llm/conversation_template/qwen2.py
+++ b/python/mlc_llm/conversation_template/qwen2.py
@@ -18,3 +18,17 @@
         stop_token_ids=[151643, 151645],
     )
 )
+
+ConvTemplateRegistry.register_conv_template(
+    Conversation(
+        name="qwen2-vl",
+        system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n",
+        system_message="You are a helpful assistant.",
+        roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"},
+        seps=["<|im_end|>\n"],
+        role_content_sep="\n",
+        role_empty_sep="\n",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+        stop_token_ids=[151643, 151645],
+    )
+)
\ No newline at end of file
diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py
index 29eafd90ce..8115ce705d 100644
--- a/python/mlc_llm/interface/gen_config.py
+++ b/python/mlc_llm/interface/gen_config.py
@@ -267,6 +267,7 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     "chatml",
     "chatml_nosystem",
     "qwen2",
+    "qwen2-vl",
     "open_hermes_mistral",
     "neural_hermes_mistral",
     "llama_default",
diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py
index b37cb8fedb..27fb47e5ce 100644
--- a/python/mlc_llm/model/model.py
+++ b/python/mlc_llm/model/model.py
@@ -38,6 +38,7 @@
 from .qwen import qwen_loader, qwen_model, qwen_quantization
 from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
 from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
+from .qwen2_vl import qwen2_vl_loader, qwen2_vl_model, qwen2_vl_quantization
 from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
 from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
 from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
@@ -299,6 +300,20 @@ class Model:
             "ft-quant": qwen2_moe_quantization.ft_quant,
         },
     ),
+    "qwen2_vl": Model(
+        name="qwen2_vl",
+        model=qwen2_vl_model.QWen2VLForCausalLM,
+        config=qwen2_vl_model.QWen2VLConfig,
+        source={
+            "huggingface-torch": qwen2_vl_loader.huggingface,
+            "huggingface-safetensor": qwen2_vl_loader.huggingface,
+        },
+        quantize={
+            "no-quant": qwen2_vl_quantization.no_quant,
+            "group-quant": qwen2_vl_quantization.group_quant,
+            "ft-quant": qwen2_vl_quantization.ft_quant,
+        },
+    ),
     "deepseek_v2": Model(
         name="deepseek_v2",
model=deepseek_v2_model.DeepseekV2ForCausalLM, diff --git a/python/mlc_llm/model/qwen2_vl/__init__.py b/python/mlc_llm/model/qwen2_vl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/vision/image_processing.py b/python/mlc_llm/model/vision/image_processing.py index bccd8ed2a0..7191b59275 100644 --- a/python/mlc_llm/model/vision/image_processing.py +++ b/python/mlc_llm/model/vision/image_processing.py @@ -84,7 +84,7 @@ def get_output_image_size(image: Tensor): assert False, "not supported resize parameter" (new_h, new_w) = get_output_image_size(image) - out = op.interpolate(image, (new_h, new_w), data_layout="NCHW", mode="linear") + out = op.interpolate(image, (new_h, new_w), data_layout="NCHW", mode="bicubic") return out # pylint: disable=too-many-arguments,too-many-locals From 8ab0824bf14004ed526a02308cc2a10f9a42b49f Mon Sep 17 00:00:00 2001 From: Nihal John George Date: Mon, 10 Feb 2025 13:58:31 -0500 Subject: [PATCH 2/3] Add image preproc (till before patches for ViT) --- .../mlc_llm/model/qwen2_vl/qwen2_vl_image.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py index e69de29bb2..297816bdbc 100644 --- a/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py @@ -0,0 +1,87 @@ +# qwen2_vl_image.py +# Contains image preprocessing, ViT definition, and other image-related operations. + +from typing import List, Tuple + +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm.model.vision import ImageProcessor + + +class QWen2VLImagePreprocessor(nn.Module): + def __init__( + self, + do_resize: bool = True, + resample: str = "bicubic", + do_rescale: bool = True, + rescale_factor: float = 1/255.0, + do_normalize: bool = True, + image_mean: Tensor = OPENAI_CLIP_MEAN, + image_std: Tensor = OPENAI_CLIP_STD, + min_pixels: int = 56*56, + max_pixels: int = 28*28*1280, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + ): + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.image_processor = ImageProcessor() + + def smart_resize(height: int, width: int, factor: int=28, min_pixels: int = 56*56, max_pixels: int = 14*14*4*1280) -> Tuple[int, int]: + """ + Rescales the image, similar to the Huggingface implementation, so that the following conditions are met: + + 1. 
Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = tir.round(height / factor) * factor + w_bar = tir.round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = tir.sqrt((height * width) / max_pixels) + h_bar = tir.floor(height / beta / factor) * factor + w_bar = tir.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = tir.sqrt(min_pixels / (height * width)) + h_bar = tir.ceil(height * beta / factor) * factor + w_bar = tir.ceil(width * beta / factor) * factor + return h_bar, w_bar + + def forward(self, pixel_values: Tensor, resized_height, resized_width) -> Tensor: + pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW + if self.do_resize: + hbar, wbar = self.smart_resize(pixel_values.shape[2], pixel_values.shape[3], factor=self.patch_size*self.merge_size) + pixel_values = self.image_processor.resize(pixel_values, params={"height": hbar, "width": wbar}) + if self.do_rescale: + pixel_values = self.image_processor.rescale(pixel_values, factor=self.rescale_factor) + if self.do_normalize: + pixel_values = self.image_processor.normalize(pixel_values, mean=self.image_mean, std=self.image_std) + + # TODO no padding in HF but do we need? + return pixel_values + +class QWen2VLVisionTransformer: + # TODO Not CLIP, uses original ViT (CLIP also bases on this) + def __init__(self): + pass + From 86b4de97bf20d52c4f5fedbfc2c4a764216cf636 Mon Sep 17 00:00:00 2001 From: Nihal John George Date: Mon, 10 Mar 2025 13:09:23 -0400 Subject: [PATCH 3/3] Add impl (WIP testing, bugfixes) --- .../mlc_llm/model/qwen2_vl/qwen2_vl_config.py | 86 +++++ .../mlc_llm/model/qwen2_vl/qwen2_vl_image.py | 294 ++++++++++++++++-- .../mlc_llm/model/qwen2_vl/qwen2_vl_loader.py | 140 +++++++++ .../mlc_llm/model/qwen2_vl/qwen2_vl_model.py | 249 +++++++++++++++ .../model/qwen2_vl/qwen2_vl_quantization.py | 111 +++++++ tests/python/model/test_qwen2_vl.py | 189 +++++++++++ 6 files changed, 1040 insertions(+), 29 deletions(-) create mode 100644 python/mlc_llm/model/qwen2_vl/qwen2_vl_config.py create mode 100644 tests/python/model/test_qwen2_vl.py diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_config.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_config.py new file mode 100644 index 0000000000..ffaf349dd6 --- /dev/null +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_config.py @@ -0,0 +1,86 @@ +"""Configuration classes for Qwen2 VL model.""" + +import dataclasses +from typing import Any, Dict, Optional + +from mlc_llm.support.config import ConfigBase +from mlc_llm.model.qwen2.qwen2_model import QWen2Config + +@dataclasses.dataclass +class QWen2VLVisionConfig(ConfigBase): + """Configuration for the vision part of Qwen2 VL.""" + + hidden_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + patch_size: int = 14 + merge_size: int = 2 + image_size: int = 448 # Default max image size + num_channels: int = 3 + layer_norm_eps: float = 1e-6 + max_patches: int = 1024 # Maximum number of patches after merging + hidden_act: str = "gelu" + dtype: str = "float32" + 
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + num_rope_scales: int = 4 + rope_theta: float = 10000 + +@dataclasses.dataclass +class QWen2VLConfig(QWen2Config): + """Configuration for the complete Qwen2 VL model.""" + + vision_config: Optional[QWen2VLVisionConfig] = None + image_size: int = 448 + min_image_size: int = 224 + max_image_size: int = 448 + patch_size: int = 14 + merge_size: int = 2 + temporal_patch_size: int = 2 + min_patch_size: int = 14 + max_patch_size: int = 28 + min_pixels: int = 56*56 + max_pixels: int = 28*28*1280 + dtype: str = "float32" + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + # First run parent class post init + super().__post_init__() + + # Set up vision config if not provided + if self.vision_config is None: + self.vision_config = QWen2VLVisionConfig( + hidden_size=1024, # Vision hidden size + intermediate_size=4096, # Vision MLP size + num_hidden_layers=24, # Number of vision transformer layers + num_attention_heads=16, # Number of vision attention heads + patch_size=self.patch_size, + merge_size=self.merge_size, + image_size=self.image_size, + layer_norm_eps=1e-6, + dtype=self.dtype, + ) + + # Validate configuration + if self.patch_size < self.min_patch_size or self.patch_size > self.max_patch_size: + raise ValueError( + f"patch_size must be between {self.min_patch_size} and {self.max_patch_size}, " + f"got {self.patch_size}" + ) + + if self.image_size < self.min_image_size or self.image_size > self.max_image_size: + raise ValueError( + f"image_size must be between {self.min_image_size} and {self.max_image_size}, " + f"got {self.image_size}" + ) + + # Calculate maximum patches based on image size and patch size + max_h = self.max_image_size // (self.patch_size * self.merge_size) + max_w = self.max_image_size // (self.patch_size * self.merge_size) + self.vision_config.max_patches = max_h * max_w + + # Add any additional kwargs + for k, v in self.kwargs.items(): + if not hasattr(self, k): + setattr(self, k, v) \ No newline at end of file diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py index 297816bdbc..58fe0b2be6 100644 --- a/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_image.py @@ -1,33 +1,39 @@ # qwen2_vl_image.py # Contains image preprocessing, ViT definition, and other image-related operations. 
-from typing import List, Tuple +from typing import List, Optional, Tuple from tvm import relax, te, tir from tvm.relax.frontend import nn -from tvm.relax.frontend.nn import Tensor, op +from tvm.relax.frontend.nn import Module, Tensor, op from mlc_llm.model.vision import ImageProcessor +from mlc_llm.support.config import ConfigBase +from .qwen2_vl_config import QWen2VLConfig, QWen2VLVisionConfig +from mlc_llm.nn import RopeMode, apply_rotary_emb, precompute_rope_cache +# Constants from CLIP +OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] class QWen2VLImagePreprocessor(nn.Module): + """Image preprocessing for QWen2 VL, including smart resize and normalization.""" def __init__( self, do_resize: bool = True, - resample: str = "bicubic", do_rescale: bool = True, rescale_factor: float = 1/255.0, do_normalize: bool = True, - image_mean: Tensor = OPENAI_CLIP_MEAN, - image_std: Tensor = OPENAI_CLIP_STD, + image_mean: List[float] = OPENAI_CLIP_MEAN, + image_std: List[float] = OPENAI_CLIP_STD, min_pixels: int = 56*56, max_pixels: int = 28*28*1280, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2, ): + super().__init__() self.do_resize = do_resize - self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize @@ -40,48 +46,278 @@ def __init__( self.merge_size = merge_size self.image_processor = ImageProcessor() - def smart_resize(height: int, width: int, factor: int=28, min_pixels: int = 56*56, max_pixels: int = 14*14*4*1280) -> Tuple[int, int]: + def smart_resize(self, height: int, width: int, factor: int = 28) -> Tuple[int, int]: """ - Rescales the image, similar to the Huggingface implementation, so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - 3. The aspect ratio of the image is maintained as closely as possible. + Rescales the image dimensions to meet the following conditions: + 1. Both dimensions are divisible by factor (patch_size * merge_size) + 2. Total pixels within [min_pixels, max_pixels] + 3. 
Maintains aspect ratio as much as possible """ - if height < factor or width < factor: raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") elif max(height, width) / min(height, width) > 200: raise ValueError( f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" ) + + # Round to nearest multiple of factor h_bar = tir.round(height / factor) * factor w_bar = tir.round(width / factor) * factor - if h_bar * w_bar > max_pixels: - beta = tir.sqrt((height * width) / max_pixels) + + # Scale if outside pixel bounds + if h_bar * w_bar > self.max_pixels: + beta = tir.sqrt((height * width) / self.max_pixels) h_bar = tir.floor(height / beta / factor) * factor w_bar = tir.floor(width / beta / factor) * factor - elif h_bar * w_bar < min_pixels: - beta = tir.sqrt(min_pixels / (height * width)) + elif h_bar * w_bar < self.min_pixels: + beta = tir.sqrt(self.min_pixels / (height * width)) h_bar = tir.ceil(height * beta / factor) * factor w_bar = tir.ceil(width * beta / factor) * factor + return h_bar, w_bar - def forward(self, pixel_values: Tensor, resized_height, resized_width) -> Tensor: - pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW + def forward(self, pixel_values: Tensor) -> Tensor: + """Process images through resize, rescale and normalize steps.""" + # Convert NHWC to NCHW + pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) + if self.do_resize: - hbar, wbar = self.smart_resize(pixel_values.shape[2], pixel_values.shape[3], factor=self.patch_size*self.merge_size) - pixel_values = self.image_processor.resize(pixel_values, params={"height": hbar, "width": wbar}) + factor = self.patch_size * self.merge_size + h_bar, w_bar = self.smart_resize( + pixel_values.shape[2], + pixel_values.shape[3], + factor=factor + ) + pixel_values = self.image_processor.resize( + pixel_values, + params={"height": h_bar, "width": w_bar} + ) + if self.do_rescale: - pixel_values = self.image_processor.rescale(pixel_values, factor=self.rescale_factor) + pixel_values = self.image_processor.rescale( + pixel_values, + rescale_factor=self.rescale_factor + ) + if self.do_normalize: - pixel_values = self.image_processor.normalize(pixel_values, mean=self.image_mean, std=self.image_std) - - # TODO no padding in HF but do we need? 
+ pixel_values = self.image_processor.normalize(pixel_values) + return pixel_values -class QWen2VLVisionTransformer: - # TODO Not CLIP, uses original ViT (CLIP also bases on this) - def __init__(self): - pass +class QWen2VLVisionEmbeddings(nn.Module): + """Patch and position embeddings with 2D patch merging for vision input.""" + def __init__(self, config: QWen2VLVisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.merge_size = config.merge_size + + # Patch embedding + self.patch_embed = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True + ) + + # Position embedding will be added after patch merging + self.pos_embed = nn.Parameter((config.max_patches, config.hidden_size)) + + def merge_patches(self, patches: Tensor) -> Tensor: + """Merge 2x2 neighboring patches.""" + B, H, W, C = patches.shape + + # Reshape to group 2x2 patches: B, H/2, 2, W/2, 2, C + patches = op.reshape(patches, (B, H//2, 2, W//2, 2, C)) + + # Permute to B, H/2, W/2, 2, 2, C + patches = op.permute_dims(patches, (0, 1, 3, 2, 4, 5)) + + # Merge the 2x2 patches: B, H/2, W/2, 4*C + patches = op.reshape(patches, (B, H//2, W//2, 4*C)) + + return patches + + def forward(self, pixel_values: Tensor) -> Tensor: + # Get patches: B, C, H, W -> B, hidden_size, H//patch_size, W//patch_size + patches = self.patch_embed(pixel_values) + + B, C, H, W = patches.shape + + # Reshape to B, H, W, C for patch merging + patches = op.permute_dims(patches, (0, 2, 3, 1)) + + # Merge 2x2 patches + patches = self.merge_patches(patches) + + # Reshape to sequence: B, (H/2)*(W/2), 4*hidden_size + patches = op.reshape(patches, (B, -1, 4*C)) + + # Add position embeddings + seq_length = patches.shape[1] + position_embeddings = self.pos_embed[:seq_length] + patches = patches + position_embeddings + + return patches + +class QWen2VLVisionTransformer(nn.Module): + """Vision transformer with patch merging for QWen2 VL.""" + def __init__(self, config: QWen2VLConfig): + super().__init__() + self.config = config + + # Embeddings + self.embeddings = QWen2VLVisionEmbeddings(config.vision_config) + + # Transformer layers + self.layers = nn.ModuleList([ + QWen2VLVisionLayer(config.vision_config) for _ in range(config.vision_config.num_hidden_layers) + ]) + + # Final layernorm + self.post_layernorm = nn.LayerNorm( + config.vision_config.hidden_size * 4, # *4 because of patch merging + eps=config.vision_config.layer_norm_eps + ) + + def forward(self, pixel_values: Tensor) -> Tensor: + hidden_states = self.embeddings(pixel_values) + + # Apply transformer layers + for layer in self.layers: + hidden_states = layer(hidden_states) + + # Final layernorm + hidden_states = self.post_layernorm(hidden_states) + + return hidden_states + +class QWen2VLVisionLayer(nn.Module): + """Single transformer layer for vision processing.""" + def __init__(self, config: QWen2VLVisionConfig): + super().__init__() + hidden_size = config.hidden_size * 4 # *4 because of patch merging + self.attention = QWen2VLVisionAttention(config, hidden_size) + self.mlp = QWen2VLVisionMLP(config, hidden_size) + self.layernorm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.layernorm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor) -> Tensor: + # Self attention with residual + residual = hidden_states + hidden_states = self.layernorm1(hidden_states) + hidden_states = self.attention(hidden_states) + 
hidden_states = residual + hidden_states + + # MLP with residual + residual = hidden_states + hidden_states = self.layernorm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + +class QWen2VLVisionAttention(nn.Module): + """Multi-head attention with M-ROPE for vision transformer.""" + def __init__(self, config: QWen2VLVisionConfig, hidden_size: int): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // config.num_attention_heads + + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.k_proj = nn.Linear(hidden_size, hidden_size) + self.v_proj = nn.Linear(hidden_size, hidden_size) + self.o_proj = nn.Linear(hidden_size, hidden_size) + + # M-ROPE parameters + self.rope_mode = RopeMode.NORMAL # Using normal ROPE mode but with multiple scales + self.num_rope_scales = 4 # Number of frequency bands for M-ROPE + self.rope_scale = 1.0 + self.rope_theta = 10000 + self.max_position_embeddings = config.max_patches + + # Initialize rope cache with multiple scales + self.rope_cache = {} + for scale_idx in range(self.num_rope_scales): + scale = 1.0 / (2 ** scale_idx) # Geometric progression of scales + self.rope_cache[f"scale_{scale_idx}"] = precompute_rope_cache( + dim=self.head_dim, + num_heads=self.num_attention_heads, + max_seq_len=self.max_position_embeddings, + rope_mode=self.rope_mode, + rope_scale=scale * self.rope_scale, + rope_theta=self.rope_theta, + ) + + def forward(self, hidden_states: Tensor) -> Tensor: + B, L, _ = hidden_states.shape + + # Project Q, K, V + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape for attention + q = op.reshape(q, (B, L, self.num_attention_heads, self.head_dim)) + k = op.reshape(k, (B, L, self.num_attention_heads, self.head_dim)) + v = op.reshape(v, (B, L, self.num_attention_heads, self.head_dim)) + + # Apply M-ROPE: split heads into groups and apply different scales + heads_per_scale = self.num_attention_heads // self.num_rope_scales + q_scaled = [] + k_scaled = [] + + for scale_idx in range(self.num_rope_scales): + start_idx = scale_idx * heads_per_scale + end_idx = start_idx + heads_per_scale + + # Get current scale's rope cache + rope_cache = self.rope_cache[f"scale_{scale_idx}"] + + # Apply rotary embeddings with current scale + q_part = q[:, :, start_idx:end_idx, :] + k_part = k[:, :, start_idx:end_idx, :] + + q_part_scaled = apply_rotary_emb( + q_part, + rope_cache, + offset=0, + num_heads=heads_per_scale, + ) + k_part_scaled = apply_rotary_emb( + k_part, + rope_cache, + offset=0, + num_heads=heads_per_scale, + ) + + q_scaled.append(q_part_scaled) + k_scaled.append(k_part_scaled) + + # Concatenate all scaled versions + q = op.concatenate(q_scaled, axis=2) + k = op.concatenate(k_scaled, axis=2) + + # Compute attention with scaled Q, K, V + attn_output = op_ext.attention(q, k, v) + + # Project output + output = self.o_proj(attn_output) + return output + +class QWen2VLVisionMLP(nn.Module): + """MLP layer for vision transformer.""" + def __init__(self, config: QWen2VLVisionConfig, hidden_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, hidden_size) + self.act = nn.GELU() + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = 
self.fc2(hidden_states) + return hidden_states diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py index e69de29bb2..8c3a622d16 100644 --- a/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_loader.py @@ -0,0 +1,140 @@ +""" +This file specifies how MLC's Qwen2 VL parameters map from HuggingFace format. +""" + +import functools +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .qwen2_vl_config import QWen2VLConfig +from .qwen2_vl_model import QWen2VLForCausalLM + +def huggingface(model_config: QWen2VLConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping from MLC LLM parameters to HuggingFace parameters. + + Parameters + ---------- + model_config : QWen2VLConfig + The configuration of the Qwen2 VL model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = QWen2VLForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + # Vision model mappings + def _add_vision(mlc_name: str, hf_name: str = None): + if hf_name is None: + hf_name = mlc_name + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + # Vision transformer layers + vision_prefix = "vision_model" + for i in range(model_config.vision_config.num_hidden_layers): + layer_prefix = f"{vision_prefix}.layers.{i}" + _add_vision(f"{layer_prefix}.layernorm1.weight") + _add_vision(f"{layer_prefix}.layernorm1.bias") + _add_vision(f"{layer_prefix}.layernorm2.weight") + _add_vision(f"{layer_prefix}.layernorm2.bias") + + # Attention weights + _add_vision(f"{layer_prefix}.attention.q_proj.weight") + _add_vision(f"{layer_prefix}.attention.q_proj.bias") + _add_vision(f"{layer_prefix}.attention.k_proj.weight") + _add_vision(f"{layer_prefix}.attention.k_proj.bias") + _add_vision(f"{layer_prefix}.attention.v_proj.weight") + _add_vision(f"{layer_prefix}.attention.v_proj.bias") + _add_vision(f"{layer_prefix}.attention.o_proj.weight") + _add_vision(f"{layer_prefix}.attention.o_proj.bias") + + # MLP weights + _add_vision(f"{layer_prefix}.mlp.fc1.weight") + _add_vision(f"{layer_prefix}.mlp.fc1.bias") + _add_vision(f"{layer_prefix}.mlp.fc2.weight") + _add_vision(f"{layer_prefix}.mlp.fc2.bias") + + # Vision embeddings and final layer norm + _add_vision(f"{vision_prefix}.embeddings.patch_embed.weight") + _add_vision(f"{vision_prefix}.embeddings.patch_embed.bias") + _add_vision(f"{vision_prefix}.embeddings.pos_embed") + _add_vision(f"{vision_prefix}.post_layernorm.weight") + _add_vision(f"{vision_prefix}.post_layernorm.bias") + + # Vision projection + _add_vision("vision_projection.linear_1.weight", "visual_proj.0.weight") + _add_vision("vision_projection.linear_1.bias", "visual_proj.0.bias") + _add_vision("vision_projection.linear_2.weight", "visual_proj.2.weight") + _add_vision("vision_projection.linear_2.bias", "visual_proj.2.bias") + + # Language model mappings + for i in range(model_config.num_hidden_layers): + # Map attention weights + attn = f"language_model.layers.{i}.self_attn" + for weight_type in ["weight", 
"bias"]: + mlc_name = f"{attn}.c_attn.{weight_type}" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{weight_type}", + f"{attn}.k_proj.{weight_type}", + f"{attn}.v_proj.{weight_type}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Map MLP weights + mlp = f"language_model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Map remaining parameters directly + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py index e69de29bb2..a4a3d00343 100644 --- a/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_model.py @@ -0,0 +1,249 @@ +"""Implementation of the Qwen2 VL model.""" + +from typing import Optional, Tuple + +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.model.qwen2.qwen2_model import QWen2Model +from mlc_llm.nn import PagedKVCache + +from .qwen2_vl_config import QWen2VLConfig +from .qwen2_vl_image import QWen2VLImagePreprocessor, QWen2VLVisionTransformer + +class QWen2VLProjection(nn.Module): + """Projects vision features to language model dimension.""" + def __init__(self, config: QWen2VLConfig): + super().__init__() + # Input is 4x vision hidden size due to patch merging + vision_hidden_size = config.vision_config.hidden_size * 4 + + # Project to language model dimension with two-layer MLP + self.linear_1 = nn.Linear(vision_hidden_size, config.hidden_size, bias=True) + self.act = nn.GELU() + self.linear_2 = nn.Linear(config.hidden_size, config.hidden_size, bias=True) + + def forward(self, image_features: Tensor) -> Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + +class QWen2VLForCausalLM(nn.Module): + """Qwen2 VL model combining vision and language capabilities.""" + + def __init__(self, config: QWen2VLConfig): + super().__init__() + self.config = config + + # Vision components + self.image_processor = QWen2VLImagePreprocessor( + min_pixels=config.min_pixels, + max_pixels=config.max_pixels, + patch_size=config.patch_size, + merge_size=config.merge_size, + temporal_patch_size=config.temporal_patch_size, + ) + self.vision_model = QWen2VLVisionTransformer(config) + self.vision_projection = QWen2VLProjection(config) + + # Language model + self.language_model = QWen2Model(config) + + # Final LM head (reuse embedding weight if tied) + self.tie_word_embeddings = config.tie_word_embeddings + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Model attributes needed for integration + self.dtype = config.dtype + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = 
config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.vocab_size = config.vocab_size + self.tensor_parallel_shards = config.tensor_parallel_shards + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def image_preprocess(self, pixel_values: Tensor) -> Tensor: + """Preprocess images for the vision encoder.""" + return self.image_processor(pixel_values) + + def image_embed(self, pixel_values: Tensor) -> Tensor: + """Get image embeddings from preprocessed images.""" + # Process through vision transformer + vision_outputs = self.vision_model(pixel_values) + + # Project to language model dimension + image_embeds = self.vision_projection(vision_outputs) + + return image_embeds + + def embed(self, input_ids: Tensor): + """Get text embeddings.""" + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.language_model.embed_tokens(input_ids) + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + """Forward pass for both vision and language.""" + op_ext.configure() + + hidden_states = self.language_model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + + if self.tie_word_embeddings: + logits = self.language_model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + """Prefill KV cache.""" + op_ext.configure() + + def _index(x: te.Tensor): + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.language_model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + + if self.tie_word_embeddings: + logits = self.language_model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + """Decode step.""" + op_ext.configure() + + hidden_states = self.language_model(input_embed, paged_kv_cache) + if self.tie_word_embeddings: + logits = self.language_model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + """Batch prefill operation.""" + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + """Batch decode operation.""" + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + """Create paged KV cache.""" + return 
self.language_model.create_paged_kv_cache( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + ) + + def get_default_spec(self): + """Get the default module spec.""" + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "image_preprocess": { + "pixel_values": nn.spec.Tensor([1, "image_height", "image_width", 3], "uint8"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "image_embed": { + "pixel_values": nn.spec.Tensor([1, 3, "image_height", "image_width"], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + diff --git a/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py b/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py index e69de29bb2..0eb6cff35d 100644 --- a/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py +++ b/python/mlc_llm/model/qwen2_vl/qwen2_vl_quantization.py @@ -0,0 +1,111 @@ +"""This file specifies how MLC's Qwen2 VL parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .qwen2_vl_config import QWen2VLConfig +from .qwen2_vl_model import QWen2VLForCausalLM + + +def group_quant( + model_config: QWen2VLConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen2 VL model using group quantization.""" + model: nn.Module = QWen2VLForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + + # Vision model quantization patterns + vision_patterns = [ + # Vision transformer attention patterns + "vision_model.layers.*.attention.q_proj.weight", + "vision_model.layers.*.attention.k_proj.weight", + "vision_model.layers.*.attention.v_proj.weight", + "vision_model.layers.*.attention.o_proj.weight", + # Vision transformer MLP patterns + "vision_model.layers.*.mlp.fc1.weight", + 
"vision_model.layers.*.mlp.fc2.weight", + # Vision embeddings + "vision_model.embeddings.patch_embed.weight", + # Vision projection + "vision_projection.linear_1.weight", + "vision_projection.linear_2.weight", + ] + + # Add vision patterns to quantization + for pattern in vision_patterns: + quantization.add_pattern(pattern) + + # Language model patterns (from qwen2_quantization.py) + language_patterns = [ + # Attention patterns + "language_model.layers.*.self_attn.c_attn.weight", + "language_model.layers.*.self_attn.c_proj.weight", + # MLP patterns + "language_model.layers.*.mlp.gate_up_proj.weight", + "language_model.layers.*.mlp.down_proj.weight", + ] + + # Add language patterns to quantization + for pattern in language_patterns: + quantization.add_pattern(pattern) + + # Quantize the model + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: QWen2VLConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen2 VL model using FasterTransformer quantization.""" + model: nn.Module = QWen2VLForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + + # Add vision and language patterns similar to group_quant + vision_patterns = [ + "vision_model.layers.*.attention.*.weight", + "vision_model.layers.*.mlp.*.weight", + "vision_model.embeddings.patch_embed.weight", + "vision_projection.*.weight", + ] + + language_patterns = [ + "language_model.layers.*.self_attn.*.weight", + "language_model.layers.*.mlp.*.weight", + ] + + # Add patterns to quantization + for pattern in vision_patterns + language_patterns: + quantization.add_pattern(pattern) + + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: QWen2VLConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Load a Qwen2 VL model without quantization.""" + model: nn.Module = QWen2VLForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/tests/python/model/test_qwen2_vl.py b/tests/python/model/test_qwen2_vl.py new file mode 100644 index 0000000000..5cb4ac40a1 --- /dev/null +++ b/tests/python/model/test_qwen2_vl.py @@ -0,0 +1,189 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest +import numpy as np + +from tvm import relax +from tvm.relax.frontend.nn import Tensor + +from mlc_llm.model import MODEL_PRESETS, MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.quantization.group_quantization import GroupQuantizeLinear + +from mlc_llm.model.qwen2_vl.qwen2_vl_config import QWen2VLConfig, QWen2VLVisionConfig +from mlc_llm.model.qwen2_vl.qwen2_vl_image import ( + QWen2VLVisionTransformer, + QWen2VLImagePreprocessor, + QWen2VLVisionAttention, +) +from mlc_llm.nn import RopeMode, precompute_rope_cache + +def test_vision_transformer(): + """Test the vision transformer components independently.""" + # Create a basic vision config + vision_config = QWen2VLVisionConfig( + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=6, + num_attention_heads=8, + patch_size=14, + merge_size=2, + image_size=448, + ) + + # Test image preprocessor + image_processor = QWen2VLImagePreprocessor() + test_image = np.random.randint(0, 255, (1, 224, 224, 3), dtype="uint8") + processed_image = image_processor(Tensor(test_image)) + assert processed_image.shape[0] == 1 # batch size + assert 
processed_image.shape[1] == 3 # channels + assert processed_image.shape[2] % (vision_config.patch_size * vision_config.merge_size) == 0 + assert processed_image.shape[3] % (vision_config.patch_size * vision_config.merge_size) == 0 + + # Test vision transformer + vision_model = QWen2VLVisionTransformer(vision_config) + vision_output = vision_model(processed_image) + + # Check output shape (should be B, N, 4*hidden_size due to patch merging) + expected_seq_len = (processed_image.shape[2] // vision_config.patch_size // vision_config.merge_size) * \ + (processed_image.shape[3] // vision_config.patch_size // vision_config.merge_size) + assert vision_output.shape == (1, expected_seq_len, vision_config.hidden_size * 4) + +def test_m_rope_implementation(): + """Test the M-ROPE implementation in the vision transformer.""" + # Create a basic vision config with specific M-ROPE parameters + vision_config = QWen2VLVisionConfig( + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=6, + num_attention_heads=8, + patch_size=14, + merge_size=2, + image_size=448, + num_rope_scales=4, # 4 different frequency scales + rope_theta=10000, + max_patches=256, # Maximum sequence length + ) + + # Create attention module with M-ROPE + hidden_size = vision_config.hidden_size * 4 # After patch merging + attention = QWen2VLVisionAttention(vision_config, hidden_size) + + # Verify rope cache creation for each scale + assert len(attention.rope_cache) == vision_config.num_rope_scales + + # Check that each scale has a different frequency + for scale_idx in range(vision_config.num_rope_scales): + scale_key = f"scale_{scale_idx}" + assert scale_key in attention.rope_cache + + # Verify the scale factor is applied correctly (1.0, 0.5, 0.25, 0.125) + expected_scale = 1.0 / (2 ** scale_idx) + + # Create a reference rope cache with the expected scale + reference_cache = precompute_rope_cache( + dim=attention.head_dim, + num_heads=attention.num_attention_heads, + max_seq_len=vision_config.max_patches, + rope_mode=RopeMode.NORMAL, + rope_scale=expected_scale, + rope_theta=vision_config.rope_theta, + ) + + # Compare the first few values to verify scaling + # The rope cache contains cos and sin values that should differ by scale + assert np.allclose( + attention.rope_cache[scale_key]["cos"][0, 0].numpy(), + reference_cache["cos"][0, 0].numpy(), + rtol=1e-5 + ) + + # Test forward pass with M-ROPE + batch_size = 2 + seq_len = 64 + test_input = np.random.randn(batch_size, seq_len, hidden_size).astype("float32") + output = attention(Tensor(test_input)) + + # Check output shape + assert output.shape == (batch_size, seq_len, hidden_size) + +@pytest.mark.parametrize("model_name", ["qwen2_vl"]) +def test_qwen2_vl_creation(model_name: str): + """Test the creation of the full Qwen2 VL model.""" + model_info = MODELS["qwen2_vl"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + +@pytest.mark.parametrize("model_name", ["qwen2_vl"]) +@pytest.mark.parametrize( + "quant_name", + ["q3f16_1", "q4f16_1", "q4f32_1"], +) +def test_qwen2_vl_group_quantization(model_name: str, quant_name: str): + """Test group quantization of Qwen2 VL.""" + model_info = MODELS["qwen2_vl"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model, quant_map = model_info.quantize["group-quant"](config, 
QUANTIZATION[quant_name]) + + # Check vision model quantization + assert "vision_model.embeddings.patch_embed.weight" in quant_map.param_map + assert isinstance( + model.vision_model.embeddings.patch_embed, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + + # Check vision projection quantization + assert "vision_projection.linear_1.weight" in quant_map.param_map + assert isinstance( + model.vision_projection.linear_1, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + + # Check vision transformer layers + for i in range(config.vision_config.num_hidden_layers): + # Check attention weights + assert f"vision_model.layers.{i}.attention.q_proj.weight" in quant_map.param_map + assert isinstance( + model.vision_model.layers[i].attention.q_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + + # Check MLP weights + assert f"vision_model.layers.{i}.mlp.fc1.weight" in quant_map.param_map + assert isinstance( + model.vision_model.layers[i].mlp.fc1, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + + # Check language model quantization (similar to qwen2 tests) + for i in range(config.num_hidden_layers): + assert f"language_model.layers.{i}.self_attn.c_attn.weight" in quant_map.param_map + assert isinstance( + model.language_model.layers[i].self_attn.c_attn, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) + +@pytest.mark.parametrize("model_name", ["qwen2_vl"]) +@pytest.mark.parametrize( + "quant_name", + ["q0"], +) +def test_qwen2_vl_no_quantization(model_name: str, quant_name: str): + """Test no-quantization mode of Qwen2 VL.""" + model_info = MODELS["qwen2_vl"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + _, quant_map = model_info.quantize["no-quant"](config, QUANTIZATION[quant_name]) + assert len(quant_map.param_map) == 0 + assert len(quant_map.map_func) == 0 + +if __name__ == "__main__": + test_vision_transformer() + test_m_rope_implementation() + test_qwen2_vl_creation("qwen2_vl") + test_qwen2_vl_group_quantization("qwen2_vl", "q4f16_1") + test_qwen2_vl_no_quantization("qwen2_vl", "q0") \ No newline at end of file
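
A plain-Python sketch of the smart_resize constraints used in patches 2 and 3 (a hypothetical reference helper using the math module in place of tir, not part of the series): it mirrors the same round/floor/ceil logic so the TIR version can be sanity-checked, e.g. that outputs are factor-aligned and stay inside the pixel budget.

import math

def smart_resize_ref(height: int, width: int, factor: int = 28,
                     min_pixels: int = 56 * 56, max_pixels: int = 28 * 28 * 1280):
    """Reference re-implementation of the resize constraints: both output dims
    are multiples of `factor`, the pixel count stays in [min_pixels, max_pixels],
    and the aspect ratio is roughly preserved."""
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

# Example: a 1080x1920 frame gets snapped to multiples of 28 and capped at max_pixels.
h, w = smart_resize_ref(1080, 1920)
assert h % 28 == 0 and w % 28 == 0
assert 56 * 56 <= h * w <= 28 * 28 * 1280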
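
The reshape/permute sequence in QWen2VLVisionEmbeddings.merge_patches is easy to get subtly wrong, so here is a NumPy mirror of the same steps (only a sketch with made-up shapes, not patch code) verifying that each merged cell is the channel-concatenation of its 2x2 patch neighbourhood in row-major order.

import numpy as np

def merge_patches_ref(patches: np.ndarray) -> np.ndarray:
    # Mirror of merge_patches: (B, H, W, C) -> (B, H//2, W//2, 4*C)
    B, H, W, C = patches.shape
    x = patches.reshape(B, H // 2, 2, W // 2, 2, C)   # group rows/cols in pairs
    x = x.transpose(0, 1, 3, 2, 4, 5)                 # B, H/2, W/2, 2, 2, C
    return x.reshape(B, H // 2, W // 2, 4 * C)        # concat each 2x2 block on channels

B, H, W, C = 1, 4, 6, 3
patches = np.arange(B * H * W * C, dtype=np.float32).reshape(B, H, W, C)
merged = merge_patches_ref(patches)
assert merged.shape == (B, H // 2, W // 2, 4 * C)

# The (0, 0) output cell holds patches (0,0), (0,1), (1,0), (1,1) concatenated on channels.
expected = np.concatenate(
    [patches[0, 0, 0], patches[0, 0, 1], patches[0, 1, 0], patches[0, 1, 1]]
)
assert np.array_equal(merged[0, 0, 0], expected)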
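
Finally, the per-head-group scaling that QWen2VLVisionAttention calls M-ROPE: the NumPy sketch below only illustrates the head splitting, per-group rotary application, and concatenation pattern from the patch. How rope_scale actually enters precompute_rope_cache in mlc_llm.nn is assumed here (positions multiplied by the scale), so treat this as an illustration of the head-grouping idea rather than a reference for the MLC kernels.

import numpy as np

def rope_cos_sin(seq_len, head_dim, scale=1.0, theta=10000.0):
    # Standard rotary tables; `scale` multiplies the position index (assumption),
    # matching the geometric progression of scales used in the patch.
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    ang = (np.arange(seq_len)[:, None] * scale) * inv_freq[None, :]
    return np.cos(ang), np.sin(ang)

def apply_rope(x, cos, sin):
    # x: (B, L, H, D); rotate (even, odd) pairs along the head dimension.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos[None, :, None, :] - x2 * sin[None, :, None, :]
    out[..., 1::2] = x1 * sin[None, :, None, :] + x2 * cos[None, :, None, :]
    return out

B, L, H, D, n_scales = 1, 16, 8, 32, 4
q = np.random.randn(B, L, H, D).astype(np.float32)
heads_per_scale = H // n_scales
q_groups = []
for s in range(n_scales):
    cos, sin = rope_cos_sin(L, D, scale=1.0 / (2 ** s))  # scale_0 .. scale_3 as in the patch
    q_groups.append(apply_rope(q[:, :, s * heads_per_scale:(s + 1) * heads_per_scale, :], cos, sin))
q_rot = np.concatenate(q_groups, axis=2)  # heads regrouped in their original order
assert q_rot.shape == q.shape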