@@ -174,10 +174,10 @@ def get_hf_config(self):
         return self.ctx.model_config.hf_config

     def get_supported_mm_limits(self):
-        return {"image": None, "video": None}
+        return {"image": None}

     def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
-        return {"image": self.get_max_image_tokens(), "video": 0}
+        return {"image": self.get_max_image_tokens()}

     def get_max_image_tokens(self) -> int:
         width, height = self.get_max_image_size()
@@ -750,7 +750,6 @@ def load_weights(self, weights: Iterable[tuple[str,
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-@support_torch_compile
 class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
                                   SupportsPP, SupportsMultiModal):
     embedding_padding_modules = ["lm_head"]
@@ -857,12 +856,11 @@ def get_multimodal_embeddings(self, **kwargs):
         if pixel_values is not None:
             if isinstance(pixel_values, torch.Tensor):
                 pixel_values = pixel_values.flatten(0, 1).to(self.dtype)
-                if isinstance(num_image_patches, list):
-                    num_image_patches = torch.cat(num_image_patches)
-                num_image_patches = num_image_patches.flatten()
             else:
                 pixel_values = torch.cat(pixel_values).to(self.dtype)
-                num_image_patches = torch.cat(num_image_patches).flatten()
+
+            if isinstance(num_image_patches, list):
+                num_image_patches = torch.cat(num_image_patches)

             vision_embeddings = self.model.model.get_image_features(
                 pixel_values,
@@ -880,7 +878,7 @@ def get_multimodal_embeddings(self, **kwargs):
             # but transformers returns concat tensors if each patch
             # is of different size. We split it back to make vLLM happy
             vision_embeddings = torch.split(vision_embeddings,
-                                            num_image_patches.tolist())
+                                            num_image_patches.flatten().tolist())
             vision_embeddings = [
                 embed.flatten(start_dim=0, end_dim=-2)
                 for embed in vision_embeddings
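
For context on the last two hunks: `num_image_patches` is now concatenated once (whether it arrives as a list of tensors or a single tensor), and the `.flatten()` is deferred to the `torch.split` call that actually consumes the per-image patch counts. Below is a minimal standalone sketch of that behavior; the patch counts and the embedding shape are made up for illustration and are not taken from the PR.

```python
# Illustrative sketch only: made-up patch counts and hidden size.
import torch

# Pretend the processor returned per-request patch counts as a list of
# 1-D tensors: request 1 has images with 4 and 6 patches, request 2 has 5.
num_image_patches = [torch.tensor([4, 6]), torch.tensor([5])]

# Concatenate once, regardless of which pixel_values branch was taken.
if isinstance(num_image_patches, list):
    num_image_patches = torch.cat(num_image_patches)

# Concatenated vision features for all 4 + 6 + 5 = 15 patches (hidden size 8).
vision_embeddings = torch.randn(15, 8)

# Flatten only where the per-image counts are consumed, mirroring the diff.
splits = torch.split(vision_embeddings, num_image_patches.flatten().tolist())
print([t.shape for t in splits])
# [torch.Size([4, 8]), torch.Size([6, 8]), torch.Size([5, 8])]
```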