Commit ecc2cf4
Unified Pixtral (#164)
1 parent 82d5104 commit ecc2cf4

File tree

5 files changed: +226 additions, -60 deletions

mlx_engine/model_kit/model_kit.py
mlx_engine/model_kit/vision_add_ons/gemma3.py
mlx_engine/model_kit/vision_add_ons/load_utils.py
mlx_engine/model_kit/vision_add_ons/pixtral.py
requirements.txt

mlx_engine/model_kit/model_kit.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 from mlx_engine.logging import log_info, log_warn
 from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
 from mlx_engine.model_kit.vision_add_ons.gemma3 import Gemma3VisionAddOn
+from mlx_engine.model_kit.vision_add_ons.pixtral import PixtralVisionAddOn
 from mlx_engine.utils.kv_cache_quantization import get_kv_cache_quantization_params
 from mlx_engine.utils.prompt_processing import process_prompt_text_only
 
@@ -33,6 +34,7 @@ class ModelKit:
 
     VISION_ADD_ON_MAP = {
         "gemma3": Gemma3VisionAddOn,
+        "pixtral": PixtralVisionAddOn,
     }
 
     # model state tracking
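
With this entry registered, a Pixtral checkpoint can be routed to its vision add-on by model type. A minimal sketch of the dispatch, assuming the model type is read from config.json as elsewhere in this commit (the actual lookup code in ModelKit is outside this diff, and the model directory below is a placeholder):

import json
from pathlib import Path

from mlx_engine.model_kit.model_kit import ModelKit

model_path = Path("my-pixtral-model")  # placeholder local model directory
model_type = json.loads((model_path / "config.json").read_text())["model_type"]

# Resolve "pixtral" -> PixtralVisionAddOn; unknown model types get no vision add-on
add_on_class = ModelKit.VISION_ADD_ON_MAP.get(model_type)
vision_add_on = add_on_class(model_path) if add_on_class is not None else None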

mlx_engine/model_kit/vision_add_ons/gemma3.py

Lines changed: 20 additions & 58 deletions

@@ -1,8 +1,8 @@
-import glob
-import json
 from typing import List
+from pathlib import Path
 
 from mlx import nn
+import mlx.core as mx
 
 from mlx_vlm.models.gemma3 import (
     VisionModel as Gemma3VisionTower,
@@ -12,76 +12,36 @@
     Model as Gemma3CombinedModel,  # for prepare_inputs_for_multimodal
 )
 from mlx_vlm.models.gemma3.gemma3 import Gemma3MultiModalProjector
-from mlx_vlm.utils import sanitize_weights, load_processor, get_class_predicate
 
-from pathlib import Path
-import mlx.core as mx
-
-from mlx_engine.logging import log_info
+from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
 from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import (
     common_process_prompt_with_images,
 )
-from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
+from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon
 
 
-class Gemma3VisionAddOn(BaseVisionAddOn, nn.Module):
+class Gemma3VisionAddOn(BaseVisionAddOn):
     """
-    Vision add-on for Gemma3 model. Uses mlx-vlm vision components of Gemma3
+    Vision add-on for Gemma3 model. Uses mlx-vlm vision components of Gemma3.
     """
 
     GEMMA3_LOG_PREFIX = "Gemma3VisionAddOn"
 
     def __init__(self, model_path: Path):
+        """Initialize Gemma3VisionAddOn with vision components loaded from the given path."""
         super().__init__()
-        config_dict = json.loads((model_path / "config.json").read_text())
-        self.config = Gemma3ModelConfig.from_dict(config_dict)
-        self.config.vision_config = Gemma3VisionConfig.from_dict(
-            self.config.vision_config
-        )
-        self.config.text_config = Gemma3TextConfig.from_dict(self.config.text_config)
-        self.vision_tower = Gemma3VisionTower(self.config.vision_config)
-        self.multi_modal_projector = Gemma3MultiModalProjector(self.config)
-        self.processor = load_processor(model_path=model_path, add_detokenizer=True)
-        # load the weights for the vision tower
-        # ref: https://github.com/Blaizzy/mlx-vlm/blob/d2391123cabac313729f9a2a8d57d396e2592f20/mlx_vlm/utils.py#L147
-        # and https://github.com/Blaizzy/mlx-vlm/blob/d2391123cabac313729f9a2a8d57d396e2592f20/mlx_vlm/models/gemma3/gemma3.py#L86-L87
-        weight_files = glob.glob(str(model_path / "*.safetensors"))
-        if not weight_files:
-            raise FileNotFoundError(
-                f"Failed to load Gemma3 vision model: {model_path} does not contain any safetensors files"
-            )
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf))
-        # filter out everything but weights with keys that start with "vision_tower" or "multi_modal_projector"
-        weights = {
-            k: v
-            for k, v in weights.items()
-            if k.startswith("vision_tower") or k.startswith("multi_modal_projector")
-        }
-        weights = sanitize_weights(
-            Gemma3VisionTower, weights, self.config.vision_config
-        )
-        # perform jit quantization if needed
-        if (quantization := config_dict.get("quantization", None)) is not None:
-            class_predicate = get_class_predicate(skip_vision=False, weights=weights)
-            nn.quantize(
-                self,
-                **quantization,
-                class_predicate=class_predicate,
-            )
 
-        # load weights using nn.Module method
-        self.load_weights(list(weights.items()))
-        # hardcode lazy loading to false for now, always load weights to memory here
-        lazy = False
-        if not lazy:
-            mx.eval(self.parameters())
-
-        self.eval()
-        log_info(
-            prefix=self.GEMMA3_LOG_PREFIX,
-            message=f"Gemma3 vision model loaded successfully from {model_path}",
+        # Load vision model components, configuration, and processor
+        self.vision_tower, self.multi_modal_projector, self.config, self.processor = (
+            load_vision_addon(
+                model_path=model_path,
+                model_config_class=Gemma3ModelConfig,
+                vision_config_class=Gemma3VisionConfig,
+                text_config_class=Gemma3TextConfig,
+                vision_tower_class=Gemma3VisionTower,
+                multi_modal_projector_class=Gemma3MultiModalProjector,
+                log_prefix=self.GEMMA3_LOG_PREFIX,
+            )
         )
 
     def compute_embeddings(
@@ -90,6 +50,7 @@ def compute_embeddings(
         prompt_tokens: mx.array,
         images_b64: List[str],
     ) -> mx.array:
+        """Compute embeddings for text with images."""
        input_ids, pixel_values, attention_mask, other_model_inputs = (
            common_process_prompt_with_images(
                prompt_tokens=prompt_tokens,
@@ -105,6 +66,7 @@ def compute_embeddings(
             pixel_values.transpose(0, 2, 3, 1).astype(input_embeddings.dtype),
             output_hidden_states=True,
         )
+
         # Format image features
         image_features = hidden_state.astype(pixel_values.dtype)
         image_features = self.multi_modal_projector(image_features)
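
From a caller's perspective the Gemma3 add-on is constructed exactly as before; only the weight loading has moved into load_vision_addon. A minimal sketch (the model directory is a placeholder):

from pathlib import Path

from mlx_engine.model_kit.vision_add_ons.gemma3 import Gemma3VisionAddOn

# Placeholder path; construction loads the vision tower, projector, config, and processor
add_on = Gemma3VisionAddOn(Path("my-gemma3-model"))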

mlx_engine/model_kit/vision_add_ons/load_utils.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
+import glob
+import json
+from pathlib import Path
+from typing import Any, Tuple, Type
+
+import mlx.core as mx
+from mlx import nn
+
+from mlx_vlm.utils import sanitize_weights, load_processor, get_class_predicate
+from mlx_engine.logging import log_info
+
+
+def load_vision_addon(
+    model_path: Path,
+    model_config_class: Any,
+    vision_config_class: Any,
+    text_config_class: Any,
+    vision_tower_class: Type[nn.Module],
+    multi_modal_projector_class: Type[nn.Module],
+    log_prefix: str,
+) -> Tuple[nn.Module, nn.Module, Any, Any]:
+    """
+    Load vision add-on components, configuration, and processor.
+
+    Args:
+        model_path: Path to the model directory
+        model_config_class: Configuration class for the model
+        vision_config_class: Configuration class for vision component
+        text_config_class: Configuration class for text component
+        vision_tower_class: The vision tower model class
+        multi_modal_projector_class: The multi-modal projector class
+        log_prefix: Prefix for logging messages
+
+    Returns:
+        Tuple containing:
+            - The vision tower module
+            - The multi-modal projector module
+            - The model configuration
+            - The processor for handling images and text
+    """
+    # Load and parse configuration
+    config_path = model_path / "config.json"
+    if not config_path.exists():
+        raise FileNotFoundError(f"Configuration file not found at {config_path}")
+
+    config_dict = json.loads(config_path.read_text())
+    config = model_config_class.from_dict(config_dict)
+    config.vision_config = vision_config_class.from_dict(config.vision_config)
+    config.text_config = text_config_class.from_dict(config.text_config)
+
+    # Create model components
+    vision_tower = vision_tower_class(config.vision_config)
+    multi_modal_projector = multi_modal_projector_class(config)
+
+    # Combine components into a container module for loading weights
+    class VisionComponents(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.vision_tower = vision_tower
+            self.multi_modal_projector = multi_modal_projector
+
+    components = VisionComponents()
+
+    # Load processor
+    processor = load_processor(model_path=model_path, add_detokenizer=True)
+
+    # Load model weights
+    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    if not weight_files:
+        raise FileNotFoundError(
+            f"Failed to load vision add-on: {model_path} does not contain any safetensors files"
+        )
+
+    # Load and filter weights
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf))
+
+    # Filter only vision-related weights
+    vision_weights = {
+        k: v
+        for k, v in weights.items()
+        if k.startswith("vision_tower") or k.startswith("multi_modal_projector")
+    }
+
+    # Sanitize weights for vision tower
+    vision_weights = sanitize_weights(
+        vision_tower_class, vision_weights, config.vision_config
+    )
+
+    # Apply quantization if specified in config
+    if (quantization := config_dict.get("quantization", None)) is not None:
+        class_predicate = get_class_predicate(skip_vision=False, weights=vision_weights)
+        nn.quantize(
+            components,
+            **quantization,
+            class_predicate=class_predicate,
+        )
+
+    # Load weights into the model
+    components.load_weights(list(vision_weights.items()))
+
+    # Always load weights to memory here
+    mx.eval(components.parameters())
+
+    # Set model to evaluation mode
+    components.eval()
+
+    log_info(
+        prefix=log_prefix,
+        message=f"Vision add-on loaded successfully from {model_path}",
+    )
+
+    return vision_tower, multi_modal_projector, config, processor
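
The temporary VisionComponents container exists so that nn.quantize and load_weights see the same parameter tree ("vision_tower.*", "multi_modal_projector.*") that the safetensors keys use; the container is then discarded and the two submodules are returned. A sketch of a call site, mirroring the Gemma3 and Pixtral add-ons in this commit (the model directory is a placeholder):

from pathlib import Path

from mlx_vlm.models.pixtral import (
    VisionModel,
    ModelConfig,
    VisionConfig,
    TextConfig,
)
from mlx_vlm.models.pixtral.pixtral import LlavaMultiModalProjector

from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon

vision_tower, projector, config, processor = load_vision_addon(
    model_path=Path("my-pixtral-model"),  # placeholder directory
    model_config_class=ModelConfig,
    vision_config_class=VisionConfig,
    text_config_class=TextConfig,
    vision_tower_class=VisionModel,
    multi_modal_projector_class=LlavaMultiModalProjector,
    log_prefix="VisionAddOn",
)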

mlx_engine/model_kit/vision_add_ons/pixtral.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+from typing import List
+from pathlib import Path
+
+from mlx import nn
+import mlx.core as mx
+
+from mlx_vlm.models.pixtral import (
+    VisionModel as PixtralVisionTower,
+    ModelConfig as PixtralModelConfig,
+    VisionConfig as PixtralVisionConfig,
+    TextConfig as PixtralTextConfig,
+    Model as PixtralCombinedModel,  # for merge_input_ids_with_image_features
+)
+from mlx_vlm.models.pixtral.pixtral import (
+    LlavaMultiModalProjector as PixtralMultiModalProjector,
+)
+
+from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
+from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import (
+    common_process_prompt_with_images,
+)
+from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon
+
+
+class PixtralVisionAddOn(BaseVisionAddOn):
+    """
+    Vision add-on for Pixtral model. Uses mlx-vlm vision components of Pixtral.
+    """
+
+    PIXTRAL_LOG_PREFIX = "PixtralVisionAddOn"
+
+    def __init__(self, model_path: Path):
+        """Initialize PixtralVisionAddOn with vision components loaded from the given path."""
+        super().__init__()
+
+        # Load vision model components, configuration, and processor
+        self.vision_tower, self.multi_modal_projector, self.config, self.processor = (
+            load_vision_addon(
+                model_path=model_path,
+                model_config_class=PixtralModelConfig,
+                vision_config_class=PixtralVisionConfig,
+                text_config_class=PixtralTextConfig,
+                vision_tower_class=PixtralVisionTower,
+                multi_modal_projector_class=PixtralMultiModalProjector,
+                log_prefix=self.PIXTRAL_LOG_PREFIX,
+            )
+        )
+
+    def compute_embeddings(
+        self,
+        text_model: nn.Module,
+        prompt_tokens: mx.array,
+        images_b64: List[str],
+    ) -> mx.array:
+        """Compute embeddings for text with images."""
+        input_ids, pixel_values, attention_mask, other_model_inputs = (
+            common_process_prompt_with_images(
+                prompt_tokens=prompt_tokens,
+                images_b64=images_b64,
+                processor=self.processor,
+                config=self.config,
+            )
+        )
+        input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
+
+        if isinstance(pixel_values, list):
+            pixel_values = mx.concatenate(
+                [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
+            )
+        if pixel_values.ndim == 3:
+            pixel_values = pixel_values[None, ...]
+
+        # Process image through vision tower
+        *_, hidden_states = self.vision_tower(
+            pixel_values.transpose(0, 2, 3, 1),
+            output_hidden_states=True,
+        )
+        # Select the hidden states from the desired layer
+        selected_image_feature = hidden_states[self.config.vision_feature_layer]
+
+        # Pass image features through the multi-modal projector
+        image_features = self.multi_modal_projector(selected_image_feature)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = PixtralCombinedModel.merge_input_ids_with_image_features(
+            self.config.image_token_index, image_features, input_embeddings, input_ids
+        )
+        return final_inputs_embeds.squeeze(0)  # remove batch dimension
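
To make the final merge step concrete: merge_input_ids_with_image_features replaces the embedding at each image-token position with a row of projected image features. A conceptual illustration of that idea (not the mlx-vlm implementation; the token ids and sizes are made up):

import mlx.core as mx

IMAGE_TOKEN = 10  # made-up image token id
input_ids = mx.array([1, IMAGE_TOKEN, IMAGE_TOKEN, 2])
text_embeds = mx.ones((4, 8))    # (seq_len, hidden) text embeddings
image_feats = mx.zeros((2, 8))   # one projected row per image token

rows, img_pos = [], 0
for i, tok in enumerate(input_ids.tolist()):
    if tok == IMAGE_TOKEN:
        rows.append(image_feats[img_pos])  # splice in the image feature
        img_pos += 1
    else:
        rows.append(text_embeds[i])        # keep the text embedding
merged = mx.stack(rows)  # (seq_len, hidden), with images spliced into the text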

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -22,8 +22,8 @@ jsonschema-specifications==2024.10.1
 jsonschema==4.23.0
 lark==1.2.2
 markupsafe==2.1.5
-mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f93589cb
-mlx-vlm==0.1.26
+mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@064c75d
+mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm.git@51eecac
 mlx==0.25.2
 mpmath==1.3.0
 multidict==6.1.0
