@@ -821,7 +821,17 @@ def forward(
821
821
positions : torch .Tensor ,
822
822
intermediate_tensors : Optional [IntermediateTensors ] = None ,
823
823
inputs_embeds : Optional [torch .Tensor ] = None ,
824
+ ** kwargs : object ,
824
825
) -> Union [torch .Tensor , IntermediateTensors ]:
826
+ # NOTE: In v1, inputs_embeds is always generated at model runner from
827
+ # `get_multimodal_embeddings` and `get_input_embeddings`, this
828
+ # condition is only for v0 compatibility.
829
+ if inputs_embeds is None :
830
+ multimodal_embeds = self .get_multimodal_embeddings (** kwargs )
831
+ if multimodal_embeds is not None :
832
+ inputs_embeds = self .get_input_embeddings (input_ids , multimodal_embeds )
833
+ input_ids = None
834
+
825
835
model_output = self .model (input_ids , positions , intermediate_tensors ,
826
836
inputs_embeds )
827
837
return model_output
@@ -850,11 +860,11 @@ def get_multimodal_embeddings(self, **kwargs):
850
860
pixel_values = pixel_values if pixel_values is not None else kwargs .pop (
851
861
"image_patches" , None )
852
862
image_embeds = kwargs .pop ("image_embeds" , None )
853
- num_image_patches = kwargs .pop ("num_image_patches" )
854
863
855
864
if pixel_values is None and image_embeds is None :
856
865
return None
857
866
867
+ num_image_patches = kwargs .pop ("num_image_patches" )
858
868
if pixel_values is not None :
859
869
if isinstance (pixel_values , torch .Tensor ):
860
870
pixel_values = pixel_values .flatten (0 , 1 ).to (self .dtype )
0 commit comments