
Commit 8c26b18

fix Blip2 loading and inference bugs #1902 #1904 #1905 (#1958)
1 parent 3ea3f4a commit 8c26b18

3 files changed: +55, -23 lines


mindnlp/transformers/generation/utils.py

Lines changed: 7 additions & 0 deletions
@@ -1297,6 +1297,13 @@ def _prepare_generated_length(
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
+        # by default let's always generate 20 new tokens
+        elif has_default_max_length:
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

         # same for min length
         if generation_config.min_new_tokens is not None:
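
The effect of this branch: when the caller leaves max_length at the library default (20) and generation starts from a non-trivial prompt, the budget is extended by the prompt length and clamped to the model's positional capacity, so long BLIP-2 prompts no longer exhaust max_length before any new tokens are produced. A minimal, framework-free sketch of that arithmetic (the function and parameter names here are illustrative, not part of the patch):

# Illustrative sketch of the new default-max_length handling (not the patched function itself).
def resolve_max_length(default_max_length, prompt_length, max_position_embeddings=None):
    # Extend the default budget (~20 new tokens) by the prompt length ...
    max_length = default_max_length + prompt_length
    # ... but never beyond the model's positional-embedding capacity.
    if max_position_embeddings is not None:
        max_length = min(max_length, max_position_embeddings)
    return max_length

# Example: 20 default new tokens, a 32-token prompt, a 512-position model -> budget of 52.
assert resolve_max_length(20, 32, 512) == 52
# A prompt longer than the context window is clamped to the window.
assert resolve_max_length(20, 600, 512) == 512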

mindnlp/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 33 additions & 19 deletions
@@ -2881,31 +2881,45 @@ def generate(
             *language_model_inputs.shape[:-1], dtype=mindspore.int64
         )
         if input_ids is None:
-            input_ids = (
-                mindspore.Tensor([[self.config.text_config.bos_token_id]])
-                .repeat(batch_size, 1)
-            )
+            start_tokens = [self.config.text_config.bos_token_id]
+            if getattr(self.config, "image_token_index", None) is not None:
+                start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
+            input_ids = ops.tile(mindspore.Tensor([start_tokens]), (batch_size, 1))
+
+        inputs_embeds = self.get_input_embeddings()(input_ids)
         if attention_mask is None:
             attention_mask = ops.ones_like(input_ids)
-        attention_mask = ops.cat([language_attention_mask, attention_mask], dim=1)

-        # concatenate query embeddings with prompt embeddings
-        inputs_embeds = self.get_input_embeddings()(input_ids)
-        inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+        # if the model already has "image_token_index" then the input is expanded to account for image embeds
+        # otherwise we expand manually by concatenating
+        if getattr(self.config, "image_token_index", None) is not None:
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        else:
+            logger.warning_once(
+                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
+                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+            attention_mask = ops.cat(
+                [language_attention_mask, attention_mask], dim=1
+            )

-        # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
-        # -1 is to account for the prepended BOS after `generate.`
-        # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
-        if not self.language_model.config.is_encoder_decoder:
-            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
-            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+            # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
+            # -1 is to account for the prepended BOS after `generate.`
+            # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
+            if not self.language_model.config.is_encoder_decoder:
+                generate_kwargs["max_length"] = (
+                    generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
+                )
+                generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids

+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
         return outputs

 __all__ = [
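
The key change in generate is the masked scatter: when the config defines image_token_index, the prompt already carries num_query_tokens placeholder tokens, and their embedding slots are overwritten in place with the projected Q-Former outputs instead of concatenating those outputs in front of the prompt. A dependency-free sketch of that replacement with toy values (the names IMAGE_TOKEN, embed, and the tiny 2-dimensional vectors are illustrative only):

# Toy illustration of the image-token scatter used in the new generate() path.
IMAGE_TOKEN = 99        # stands in for config.image_token_index
NUM_QUERY_TOKENS = 3    # stands in for config.num_query_tokens
BOS = 1                 # stands in for config.text_config.bos_token_id

# Prompt as built when input_ids is None: query placeholders followed by BOS.
input_ids = [IMAGE_TOKEN] * NUM_QUERY_TOKENS + [BOS]

def embed(token_id):
    # Stand-in for get_input_embeddings(): a 2-dimensional "embedding" per token.
    return [float(token_id), float(token_id)]

inputs_embeds = [embed(t) for t in input_ids]

# Projected Q-Former outputs, one vector per query token.
language_model_inputs = [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]

# Overwrite only the placeholder slots; text-token embeddings stay untouched.
queries = iter(language_model_inputs)
inputs_embeds = [next(queries) if t == IMAGE_TOKEN else e
                 for t, e in zip(input_ids, inputs_embeds)]

assert inputs_embeds[:3] == language_model_inputs
assert inputs_embeds[3] == [1.0, 1.0]   # the BOS embedding is unchanged

Scattering into pre-allocated positions keeps input_ids, attention_mask, and inputs_embeds the same length, which is what allows input_ids to be passed back to the language model's generate call in the decoder-only case.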

mindnlp/transformers/models/blip_2/processing_blip_2.py

Lines changed: 15 additions & 4 deletions
@@ -20,7 +20,7 @@

 from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from ....utils import TensorType


@@ -36,13 +36,16 @@ class Blip2Processor(ProcessorMixin):
             An instance of [`BlipImageProcessor`]. The image processor is a required input.
         tokenizer (`AutoTokenizer`):
             An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
+        num_query_tokens (`int`, *optional*):
+            Number of tokens used by the Qformer as queries, should be same as in model's config.
     """
     attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["num_query_tokens"]
     image_processor_class = "BlipImageProcessor"
     tokenizer_class = "AutoTokenizer"

     # Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
-    def __init__(self, image_processor, tokenizer):
+    def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
         """
         Initializes a new instance of the Blip2Processor class.

@@ -53,16 +56,24 @@ def __init__(self, image_processor, tokenizer):
             tokenizer: An object representing the tokenizer to be used.
                 t should have the necessary methods and attributes required for tokenization.
                 The 'return_token_type_ids' attribute of the tokenizer will be set to False.
-
+            num_query_tokens (`int`, *optional*):
+                Number of tokens used by the Qformer as queries, should be same as in model's config.
         Returns:
             None.

         Raises:
             None.
         """
         tokenizer.return_token_type_ids = False
+        self.current_processor = image_processor
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
+        self.num_query_tokens = num_query_tokens
+
         super().__init__(image_processor, tokenizer)
-        self.current_processor = self.image_processor

     # Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
     def __call__(
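
The constructor change mirrors the model-side fix: the processor now guarantees the tokenizer has an image placeholder token and remembers num_query_tokens, so prompts can carry the placeholders that generate later scatters over. A dependency-free sketch of that registration logic (the DummyTokenizer class and init_processor_state helper are illustrative, not mindnlp's API):

# Illustrative stand-in for the tokenizer attributes the patched __init__ relies on.
class DummyTokenizer:
    def __init__(self):
        self.special_tokens = []
        self.return_token_type_ids = True

    def add_tokens(self, tokens, special_tokens=False):
        # Record newly registered tokens, mirroring tokenizer.add_tokens.
        self.special_tokens.extend(tokens)

def init_processor_state(tokenizer, num_query_tokens=None):
    # Same decision the patched Blip2Processor.__init__ makes:
    # register "<image>" only if the tokenizer has no image_token of its own.
    tokenizer.return_token_type_ids = False
    if not hasattr(tokenizer, "image_token"):
        image_token = "<image>"
        tokenizer.add_tokens([image_token], special_tokens=True)
    else:
        image_token = tokenizer.image_token
    return {"image_token": image_token, "num_query_tokens": num_query_tokens}

state = init_processor_state(DummyTokenizer(), num_query_tokens=32)
assert state == {"image_token": "<image>", "num_query_tokens": 32}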
