
Commit 997c881

[Model] Support multi-image for Molmo (#15438)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: e42389f

4 files changed: +39 −35 lines


docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
 - * `MolmoForCausalLM`
   * Molmo
-  * T + I
+  * T + I<sup>+</sup>
   * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
   * ✅︎
   * ✅︎
```
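The new `T + I<sup>+</sup>` entry marks Molmo as accepting multiple images per text prompt. A minimal offline-inference sketch of the new capability (the prompt wording and image paths are placeholders, not from this commit):

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Allow up to two images per prompt; Molmo needs trust_remote_code.
llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 2},
)

images = [Image.open("photo_a.jpg"), Image.open("photo_b.jpg")]
outputs = llm.generate(
    {
        "prompt": "User: What differs between these two images? Assistant:",
        "multi_modal_data": {"image": images},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```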

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -431,7 +431,7 @@
     ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
-        test_type=(VLMTestType.IMAGE),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=identity,
         max_model_len=4096,
         max_num_seqs=2,
```

vllm/model_executor/models/molmo.py

Lines changed: 28 additions & 29 deletions
```diff
@@ -57,7 +57,7 @@
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import select_patch_features
+from .vision import scatter_patch_features, select_patch_features
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
```
```diff
@@ -71,29 +71,29 @@
 
 
 class MolmoImageInputs(TypedDict):
-    images: Union[torch.Tensor, List[torch.Tensor]]
+    images: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
 
-    image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
+    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
     """Shape: `(batch_size, num_crops, num_patch)`"""
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
 
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
```

```diff
@@ -1144,13 +1144,7 @@ def __call__(
 
         image_input_idx = outputs.pop("image_input_idx", None)
         if image_input_idx is not None:
-            input_is_patch = input_ids == self.image_patch_id
-            image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
-            image_valid_flat = image_input_idx_flat >= 0
-            feat_is_patch_flat = image_valid_flat.clone()
-            feat_is_patch_flat[image_valid_flat] = (
-                input_is_patch[image_input_idx_flat[image_valid_flat]])
-            feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
+            feat_is_patch = image_input_idx >= 0
 
             input_is_embed = torch.isin(
                 input_ids,
```
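The one-liner relies on every non-negative entry of `image_input_idx` already pointing at a patch slot, making the old round-trip through `input_ids` redundant. A toy illustration with invented values:

```python
import torch

# Each entry maps an image feature to its position in input_ids;
# -1 marks padding slots that carry no feature.
image_input_idx = torch.tensor([[3, 4, -1],
                                [5, -1, -1]])
feat_is_patch = image_input_idx >= 0
# tensor([[ True,  True, False],
#         [ True, False, False]])
```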
```diff
@@ -1165,6 +1159,17 @@ def __call__(
             embed_is_patch = embed_ids == self.image_patch_id
             assert embed_is_patch.sum() == feat_is_patch.sum()
 
+            # image_tokens = extra_joint + joint
+            # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
+            embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
+            embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
+            assert len(embed_start) == len(embed_end) == len(images)
+
+            embed_is_patch = [
+                embed_is_patch[start:end + 1]
+                for start, end in zip(embed_start, embed_end)
+            ]
+
         tilings = [
             self.select_tiling(
                 image_width=image.size[0],
```
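Per the comment in the diff, each image contributes two delimited token spans (`extra_joint` plus `joint`), each bracketed by its own `im_start_id`/`im_end_id`; slicing the starts with `[::2]` and the ends with `[1::2]` therefore picks the outermost boundaries of each image. A self-contained sketch with toy token IDs:

```python
import torch

IM_START, IM_END, PATCH, TXT = 1, 2, 3, 0

# Two images, each encoded as: start .. end start .. end
embed_ids = torch.tensor([
    TXT,
    IM_START, PATCH, IM_END, IM_START, PATCH, PATCH, IM_END,  # image 0
    TXT,
    IM_START, PATCH, IM_END, IM_START, PATCH, IM_END,         # image 1
])

starts = torch.nonzero(embed_ids == IM_START)[::2, 0]   # first start per image
ends = torch.nonzero(embed_ids == IM_END)[1::2, 0]      # second end per image
print(list(zip(starts.tolist(), ends.tolist())))        # [(1, 7), (9, 14)]
```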
```diff
@@ -1180,7 +1185,7 @@ def __call__(
         outputs["num_crops"] = num_crops
         outputs["img_patch_id"] = self.image_patch_id
 
-        return BatchFeature(outputs, tensor_type=return_tensors)
+        return BatchFeature(outputs)
 
 
 class MolmoProcessingInfo(BaseProcessingInfo):
```
```diff
@@ -1190,9 +1195,7 @@ def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
         return MolmoProcessorWrapper(processor)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        # TODO: Investigate different `embed_is_patch` between cache/no-cache
-        # in multi-image case
-        return {"image": 1}
+        return {"image": None}
 
     def get_mm_max_tokens_per_item(
         self,
```
```diff
@@ -1325,7 +1328,7 @@ def _get_mm_fields_config(
                 "image", num_crops),
             feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
```
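`embed_is_patch` moves from one tensor shared across all images to one entry per image, since the masks now have per-image lengths. A rough sketch of the slicing semantics as I read these field configs (a paraphrase, not the vLLM implementation):

```python
import torch

num_crops = torch.tensor([2, 3])          # crops per image, two images
feat_is_patch = torch.rand(5, 10) > 0.5   # 5 = 2 + 3 crops, flat over images

# flat_from_sizes("image", num_crops): split the flat crop axis per image
per_image = feat_is_patch.split(num_crops.tolist())
assert [t.shape[0] for t in per_image] == [2, 3]

# batched("image"): element i of the list/first axis belongs to image i
# shared("image", num_images): one value reused for every image (img_patch_id)
```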
```diff
@@ -1499,7 +1502,7 @@ def _parse_and_validate_image_input(
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
         if isinstance(image_input["images"], list):
             # Call the vision backbone on the whole batch at once
             images_flat = flatten_bn(image_input["images"], concat=True)
```
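With `concat=True`, `flatten_bn` collapses the per-image crop tensors so the vision backbone runs once over every crop in the batch. Conceptually (a toy stand-in for the helper, with invented shapes):

```python
import torch

# Two images with different crop counts; each crop is (num_patch, patch_dim).
images = [torch.rand(2, 576, 588), torch.rand(3, 576, 588)]
images_flat = torch.cat(images)  # (5, 576, 588): all crops in one batch
```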
```diff
@@ -1530,7 +1533,7 @@ def _get_mm_embeds(
         feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
         num_crops: torch.Tensor,  # Shape: (num_images,)
         embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+    ) -> tuple[torch.Tensor, ...]:
         """
         Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.
```
```diff
@@ -1565,16 +1568,12 @@ def _get_mm_embeds(
         feats_per_image = features.split(num_crops_per_image)
         f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
 
-        _, _, embed_dim = features.shape
-        (num_embeds, ) = embed_is_patch.shape
-
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
+        features = torch.cat([
+            feats[f_is_patch]
+            for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
+        ])
 
-        return embeds_in_batch
+        return scatter_patch_features(features, embed_is_patch)
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
```

vllm/model_executor/models/vision.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
 
 def scatter_patch_features(
     features: torch.Tensor,
-    embed_is_patch: torch.Tensor,
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
 ) -> tuple[torch.Tensor, ...]:
     """
     Scatter the patch features into a contiguous tensor that corresponds
```
```diff
@@ -194,14 +194,19 @@ def scatter_patch_features(
     The resulting embedding tensor is:
     [ nan p1 p2 nan p3 p4 nan nan ]
     """
-    num_images, num_embeds = embed_is_patch.shape
-    num_embeds_per_image = [num_embeds] * num_images
+    num_embeds_per_image = [
+        e_is_patch.numel() for e_is_patch in embed_is_patch
+    ]
+    if isinstance(embed_is_patch, torch.Tensor):
+        embed_is_patch_flat = embed_is_patch.view(-1)
+    else:
+        embed_is_patch_flat = torch.cat(embed_is_patch)
 
     embeds_flat = features.new_full(
         (sum(num_embeds_per_image), features.shape[-1]),
         fill_value=torch.nan,
     )
-    embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
+    embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
 
     return embeds_flat.split(num_embeds_per_image)
```
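A toy reimplementation of the NaN-scatter pattern, to make the docstring example concrete for the new list-of-masks input (illustrative only, not the library function):

```python
import torch

def scatter_toy(features, embed_is_patch_per_image):
    # NaN-fill one row per embedding token, scatter patch features into the
    # positions flagged by the mask, then split back per image.
    sizes = [m.numel() for m in embed_is_patch_per_image]
    mask_flat = torch.cat(embed_is_patch_per_image)
    out = features.new_full((sum(sizes), features.shape[-1]), torch.nan)
    out[mask_flat] = features
    return out.split(sizes)

feats = torch.arange(4.0).reshape(4, 1)  # four patch features, embed_dim=1
masks = [torch.tensor([False, True, True, False]),
         torch.tensor([True, False, True, False])]
img0, img1 = scatter_toy(feats, masks)
# img0.squeeze() -> [nan, 0., 1., nan]; img1.squeeze() -> [2., nan, 3., nan]
```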
