```diff
@@ -57,7 +57,7 @@
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import select_patch_features
+from .vision import scatter_patch_features, select_patch_features

 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -71,29 +71,29 @@


 class MolmoImageInputs(TypedDict):
-    images: Union[torch.Tensor, List[torch.Tensor]]
+    images: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""

-    image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
+    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
     """Shape: `(batch_size, num_crops, num_patch)`"""

-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.

     Shape: `(batch_size, num_crops, num_patch)`
     """

-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.

     Shape: `(batch_size, num_embeds)`
     """

-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""

```
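With multi-image support, per-image tensors can have different numbers of crops, so these fields switch from stacked tensors to possibly-ragged lists. A minimal sketch of what one instance might look like (all shapes made up for illustration, not taken from this commit):

```python
import torch

num_patch, patch_dim = 576, 588  # hypothetical values
image_inputs = dict(  # mirrors MolmoImageInputs
    images=[
        torch.rand(5, num_patch, patch_dim),   # image 0: 5 crops
        torch.rand(13, num_patch, patch_dim),  # image 1: 13 crops
    ],
    image_masks=None,
    feat_is_patch=[
        torch.ones(5, num_patch, dtype=torch.bool),
        torch.ones(13, num_patch, dtype=torch.bool),
    ],
    embed_is_patch=[  # one 1-D mask per image; lengths differ
        torch.ones(40, dtype=torch.bool),
        torch.ones(96, dtype=torch.bool),
    ],
    num_crops=torch.tensor([5, 13]),
)
```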
```diff
@@ -1144,13 +1144,7 @@ def __call__(

         image_input_idx = outputs.pop("image_input_idx", None)
         if image_input_idx is not None:
-            input_is_patch = input_ids == self.image_patch_id
-            image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
-            image_valid_flat = image_input_idx_flat >= 0
-            feat_is_patch_flat = image_valid_flat.clone()
-            feat_is_patch_flat[image_valid_flat] = (
-                input_is_patch[image_input_idx_flat[image_valid_flat]])
-            feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
+            feat_is_patch = image_input_idx >= 0

             input_is_embed = torch.isin(
                 input_ids,
```
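The simplification holds because every non-negative entry of `image_input_idx` points at a position in `input_ids` that holds the image patch token, so gathering `input_is_patch` through it can only ever read back `True`. A toy check of the equivalence (made-up ids, not from this commit):

```python
import torch

image_patch_id = 9
input_ids = torch.tensor([1, 9, 9, 9, 2])        # patch tokens at 1..3
image_input_idx = torch.tensor([[1, 2, 3, -1]])  # -1 pads invalid slots

# Old formulation: scatter input_is_patch through the index tensor
input_is_patch = input_ids == image_patch_id
flat = image_input_idx.view(-1)
valid = flat >= 0
old = valid.clone()
old[valid] = input_is_patch[flat[valid]]
old = old.view(*image_input_idx.shape)

# New formulation: valid slots are exactly the patch slots
new = image_input_idx >= 0

assert torch.equal(old, new)
```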
```diff
@@ -1165,6 +1159,17 @@ def __call__(
             embed_is_patch = embed_ids == self.image_patch_id
             assert embed_is_patch.sum() == feat_is_patch.sum()

+            # image_tokens = extra_joint + joint
+            # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
+            embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
+            embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
+            assert len(embed_start) == len(embed_end) == len(images)
+
+            embed_is_patch = [
+                embed_is_patch[start:end + 1]
+                for start, end in zip(embed_start, embed_end)
+            ]
+
             tilings = [
                 self.select_tiling(
                     image_width=image.size[0],
```
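Each image contributes two start/end marker pairs (one around `extra_joint`, one around `joint`), so the `im_start_id` hits at even indices open an image's span and the `im_end_id` hits at odd indices close it, yielding one inclusive outer span per image. A toy illustration of the stride trick, with made-up token ids:

```python
import torch

IM_START, IM_END, PATCH = 100, 101, 9
# Per image: <start> ...extra_joint... <end> <start> ...joint... <end>
one_image = [IM_START, PATCH, PATCH, IM_END, IM_START, PATCH, IM_END]
embed_ids = torch.tensor(one_image * 2)  # two images back to back

embed_start = torch.nonzero(embed_ids == IM_START)[::2, 0]
embed_end = torch.nonzero(embed_ids == IM_END)[1::2, 0]

assert len(embed_start) == len(embed_end) == 2
spans = [(s.item(), e.item()) for s, e in zip(embed_start, embed_end)]
print(spans)  # [(0, 6), (7, 13)] -- one inclusive span per image
```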
```diff
@@ -1180,7 +1185,7 @@ def __call__(
             outputs["num_crops"] = num_crops
             outputs["img_patch_id"] = self.image_patch_id

-        return BatchFeature(outputs, tensor_type=return_tensors)
+        return BatchFeature(outputs)


 class MolmoProcessingInfo(BaseProcessingInfo):
```
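`embed_is_patch` is now a list of per-image masks with different lengths, which `tensor_type=return_tensors` would try (and fail) to convert into one stacked tensor; skipping conversion keeps the ragged lists intact. A minimal sketch, assuming stock `transformers.BatchFeature` behavior:

```python
import torch
from transformers import BatchFeature

ragged = {
    "embed_is_patch": [
        torch.ones(40, dtype=torch.bool),
        torch.ones(96, dtype=torch.bool),  # different length
    ]
}

BatchFeature(ragged)  # fine: values are stored as-is

try:
    BatchFeature(ragged, tensor_type="pt")  # attempts to stack
except ValueError as e:
    print("conversion fails:", e)
```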
```diff
@@ -1190,9 +1195,7 @@ def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
         return MolmoProcessorWrapper(processor)

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        # TODO: Investigate different `embed_is_patch` between cache/no-cache
-        # in multi-image case
-        return {"image": 1}
+        return {"image": None}

     def get_mm_max_tokens_per_item(
         self,
```
```diff
@@ -1325,7 +1328,7 @@ def _get_mm_fields_config(
                 "image", num_crops),
             feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
```
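Since there is now one `embed_is_patch` mask per image, the field becomes per-item (`batched`) instead of a single tensor shared by every item (`shared`). A conceptual sketch of the three slicing behaviors used here (illustration only, not the actual vLLM implementation):

```python
import torch

def batched(data):  # element i of the leading dim belongs to item i
    return [data[i] for i in range(len(data))]

def shared(data, num_items):  # one value, referenced by every item
    return [data] * num_items

def flat_from_sizes(data, sizes):  # flat concat, split by per-item size
    return list(data.split(sizes.tolist()))

num_crops = torch.tensor([5, 13])
feat_is_patch = torch.ones(18, 576, dtype=torch.bool)  # 5 + 13 crops
print([t.shape for t in flat_from_sizes(feat_is_patch, num_crops)])
# [torch.Size([5, 576]), torch.Size([13, 576])]
print(len(shared(torch.tensor(9), 2)))  # 2 references to the same value
```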
```diff
@@ -1499,7 +1502,7 @@ def _parse_and_validate_image_input(
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
         if isinstance(image_input["images"], list):
             # Call the vision backbone on the whole batch at once
             images_flat = flatten_bn(image_input["images"], concat=True)
```
```diff
@@ -1530,7 +1533,7 @@ def _get_mm_embeds(
         feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
         num_crops: torch.Tensor,  # Shape: (num_images,)
         embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+    ) -> tuple[torch.Tensor, ...]:
         """
         Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.
```
```diff
@@ -1565,16 +1568,12 @@ def _get_mm_embeds(
         feats_per_image = features.split(num_crops_per_image)
         f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

-        _, _, embed_dim = features.shape
-        (num_embeds, ) = embed_is_patch.shape
-
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
+        features = torch.cat([
+            feats[f_is_patch]
+            for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
+        ])

-        return embeds_in_batch
+        return scatter_patch_features(features, embed_is_patch)

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
```
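The inline NaN-fill loop moves into the shared `scatter_patch_features` helper imported at the top of the diff. A rough sketch of what that helper plausibly does, inferred only from the loop it replaces (the real implementation lives in `.vision` and its exact signature may differ):

```python
import torch

def scatter_patch_features_sketch(
    features: torch.Tensor,              # (total_patches, embed_dim)
    embed_is_patch: list[torch.Tensor],  # per-image (num_embeds,) masks
) -> tuple[torch.Tensor, ...]:
    # Pad each image's patch features out to its full embedding span,
    # NaN-filling the non-patch slots (start/end/column tokens), so the
    # result lines up 1:1 with the embedding tokens from the processor.
    out = []
    offset = 0
    for is_patch in embed_is_patch:
        n = int(is_patch.sum())
        embeds = features.new_full(
            (is_patch.numel(), features.shape[-1]), torch.nan)
        embeds[is_patch] = features[offset:offset + n]
        out.append(embeds)
        offset += n
    return tuple(out)
```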