@@ -2881,31 +2881,45 @@ def generate(
             *language_model_inputs.shape[:-1], dtype=mindspore.int64
         )
         if input_ids is None:
-            input_ids = (
-                mindspore.Tensor([[self.config.text_config.bos_token_id]])
-                .repeat(batch_size, 1)
-            )
+            start_tokens = [self.config.text_config.bos_token_id]
+            if getattr(self.config, "image_token_index", None) is not None:
+                start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
+            input_ids = ops.tile(mindspore.Tensor([start_tokens]), (batch_size, 1))
+
+        inputs_embeds = self.get_input_embeddings()(input_ids)
         if attention_mask is None:
             attention_mask = ops.ones_like(input_ids)
-        attention_mask = ops.cat([language_attention_mask, attention_mask], dim=1)
 
-        # concatenate query embeddings with prompt embeddings
-        inputs_embeds = self.get_input_embeddings()(input_ids)
-        inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+        # if the model already has "image_token_index" then the input is expanded to account for image embeds
+        # otherwise we expand manually by concatenating
+        if getattr(self.config, "image_token_index", None) is not None:
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        else:
+            logger.warning_once(
+                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
+                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+            attention_mask = ops.cat(
+                [language_attention_mask, attention_mask], dim=1
+            )
 
-        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
-        # -1 is to account for the prepended BOS after `generate`.
-        # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
-        if not self.language_model.config.is_encoder_decoder:
-            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
-            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+            # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
+            # -1 is to account for the prepended BOS after `generate`.
+            # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
+            if not self.language_model.config.is_encoder_decoder:
+                generate_kwargs["max_length"] = (
+                    generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
+                )
+                generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
 
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
 
         return outputs
 
 __all__ = [
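
The heart of this hunk is the switch from prepending the Q-Former outputs to scattering them into placeholder positions: the prompt now starts with `num_query_tokens` copies of `image_token_index` ahead of the BOS token, and the model overwrites those embedding slots in place. Below is a minimal NumPy sketch of both paths; the shapes, token ids, and variable names are illustrative only, not part of the mindnlp API.

```python
import numpy as np

BATCH, NUM_QUERY_TOKENS, HIDDEN = 2, 3, 4
IMAGE_TOKEN_INDEX, BOS_TOKEN_ID = 99, 1  # made-up ids for illustration

# New path: the prompt already carries one placeholder per query token.
start_tokens = [IMAGE_TOKEN_INDEX] * NUM_QUERY_TOKENS + [BOS_TOKEN_ID]
input_ids = np.tile(np.array([start_tokens]), (BATCH, 1))             # (2, 4)

rng = np.random.default_rng(0)
inputs_embeds = rng.normal(size=(BATCH, input_ids.shape[1], HIDDEN))
language_model_inputs = rng.normal(size=(BATCH, NUM_QUERY_TOKENS, HIDDEN))

# Scatter the image embeddings into the placeholder slots, mirroring
# `inputs_embeds[special_image_mask] = language_model_inputs.flatten()`.
special_image_mask = input_ids == IMAGE_TOKEN_INDEX
inputs_embeds[special_image_mask] = language_model_inputs.reshape(-1, HIDDEN)

# Legacy path: concatenate image embeddings in front of the text embeddings.
legacy_embeds = np.concatenate(
    [language_model_inputs, inputs_embeds[:, NUM_QUERY_TOKENS:]], axis=1
)
assert np.allclose(inputs_embeds, legacy_embeds)
```

Both layouts agree when the placeholders sit at the front of the prompt; the placeholder scheme additionally keeps `input_ids` and `inputs_embeds` aligned one-to-one, which is what lets the new code pass `input_ids` through to `generate` for decoder-only models.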
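
On the legacy concat path, the length bookkeeping keeps the user-facing budget in text tokens: the prepended image embeddings consume sequence positions, so both bounds are shifted by `language_model_inputs.shape[1]`, minus one on `max_length` for the prepended BOS noted in the comment. A worked example with made-up numbers:

```python
num_image_embeds = 32                 # stands in for language_model_inputs.shape[1]
generate_kwargs = {"max_length": 20}  # the user asks for 20 text tokens

# Shift both bounds so the image positions don't eat the text budget;
# -1 accounts for the prepended BOS mentioned in the diff's comment.
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + num_image_embeds - 1
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + num_image_embeds

assert generate_kwargs == {"max_length": 51, "min_length": 32}
```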