from mindnlp.transformers import Blip2ForConditionalGeneration, Blip2Processor
from mindnlp.core.optim import AdamW
from mindnlp.core import value_and_grad

import mindspore as ms
from mindspore.dataset import GeneratorDataset

from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import json


def freeze_blip2_backbone(model, freeze_vit=True):
    """
    Freeze the backbone of the blip2-opt model.
    If freeze_vit is True, freeze the vision model, including its embeddings and encoder.
    The language model is always frozen.
    blip2-opt model architecture:
    {
        "query_tokens": {},
        "vision_model": {
            "embeddings": {},
            "encoder": {},
            "post_layernorm": {},
        },
        "qformer": {},
        "language_projection": {},
        "language_model": {}
    }
    """
    if freeze_vit:
        for param in model.vision_model.embeddings.parameters():
            param.requires_grad = False
        for param in model.vision_model.encoder.parameters():
            param.requires_grad = False
    else:
        for param in model.vision_model.parameters():
            param.requires_grad = True

    # The language model stays frozen in either case.
    for param in model.language_model.parameters():
        param.requires_grad = False

    return model

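# Optional sanity-check helper (an addition, not part of the original script): counts how many
# parameters remain trainable after freezing, using only the `.parameters()` / `.requires_grad`
# attributes the script already relies on. The same ratio is printed inline further below;
# this just packages the check for reuse.
def count_trainable(model):
    trainable = sum(p.size for p in model.parameters() if p.requires_grad)
    total = sum(p.size for p in model.parameters())
    return trainable, total
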
class ImageCaptioningDataset:
    """Wraps a HuggingFace dataset so GeneratorDataset can yield (pixel_values, input_ids, attention_mask)."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # GeneratorDataset may pass the index as a numpy scalar; coerce to int.
        if not isinstance(idx, int):
            idx = int(idx)
        item = self.dataset[idx]
        # Pad and truncate to a fixed length of 96 tokens so batches stack cleanly.
        encoding = self.processor(images=item['image'], text=item['caption'],
                                  max_length=96, padding="max_length", truncation=True)
        return (np.asarray(encoding["pixel_values"]).squeeze(0),
                np.asarray(encoding["input_ids"]),
                np.asarray(encoding["attention_mask"]))

def get_loader(dataset, processor, batch_size, shuffle=True, num_workers=1, drop_remainder=True):
    """Builds a batched MindSpore GeneratorDataset over the captioning dataset."""
    dataset = ImageCaptioningDataset(dataset, processor)
    return GeneratorDataset(source=dataset,
                            column_names=["pixel_values", "input_ids", "attention_mask"],
                            shuffle=shuffle,
                            num_parallel_workers=num_workers
                            ).batch(batch_size=batch_size,
                                    drop_remainder=drop_remainder)

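# Debug helper (an addition, not part of the original script): prints the shapes of the first
# batch so the column order and fixed-length padding can be checked quickly. With the settings
# above this is expected to yield roughly pixel_values (B, 3, 224, 224), input_ids (B, 96),
# attention_mask (B, 96).
def peek_batch(loader):
    for pixel_values, input_ids, attention_mask in loader.create_tuple_iterator():
        print(pixel_values.shape, input_ids.shape, attention_mask.shape)
        break
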
class Trainer:
    def __init__(self, net, processor, optimizer,
                 train_dataset, eval_dataset=None, save_path=None
                 ):
        self.net = net
        self.processor = processor
        self.opt = optimizer
        self.train_dataset = train_dataset
        self.weights = self.net.trainable_params()
        self.value_and_grad = value_and_grad(fn=self.forward_fn, params_or_argnums=self.weights)
        self.run_eval = eval_dataset is not None
        self.save_path = save_path
        if self.run_eval:
            self.eval_dataset = eval_dataset
            self.testdatasetRES_list = []

    def forward_fn(self, input_ids, pixel_values, attention_mask):
        # Captioning is trained with the language-modeling loss, so the caption tokens
        # serve as both inputs and labels.
        outputs = self.net(input_ids=input_ids, pixel_values=pixel_values,
                           attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        return loss

    def train_single(self, input_ids, pixel_values, attention_mask):
        self.opt.zero_grad()
        # value_and_grad computes the loss and accumulates gradients on the trainable parameters.
        loss = self.value_and_grad(input_ids, pixel_values, attention_mask)
        self.opt.step()
        return loss

    def train(self, epochs):

        best_val_loss = float('inf')

        for epoch in range(epochs):
            print("\nEpoch {}/{}".format(epoch + 1, epochs))
            self.net.set_train(True)
            tloss = 0
            step = 0
            for batch in tqdm(self.train_dataset.create_dict_iterator(), desc='training...'):
                input_ids = batch["input_ids"]
                pixel_values = batch["pixel_values"]
                attention_mask = batch["attention_mask"]

                loss = self.train_single(input_ids, pixel_values, attention_mask)

                tloss = tloss + loss.asnumpy()
                step = step + 1

            tloss /= step
            print("\tTrain Loss {:.04f}".format(tloss))

            if self.run_eval:
                self.net.set_train(False)
                val_loss, testdatasetRES = self.eval()
                self.testdatasetRES_list.append(testdatasetRES)
                print("Epoch {} complete! Validation Loss: {}".format(epoch + 1, val_loss))
                if val_loss < best_val_loss:
                    print("Best validation loss improved from {} to {}".format(best_val_loss, val_loss))
                    best_val_loss = val_loss
                    if self.save_path is not None:
                        print("saving model...")
                        self.net.save_pretrained(self.save_path + '/best_model')

    def eval(self):
        vloss = 0
        step = 0
        test_dataset_generated_text = []
        # Disable gradient tracking while generating captions and computing the validation loss.
        with ms._no_grad():
            for batch in tqdm(self.eval_dataset.create_dict_iterator(), desc='generating image captions on test dataset'):
                input_ids = batch["input_ids"]
                pixel_values = batch["pixel_values"]
                attention_mask = batch["attention_mask"]

                # Generate captions for scoring, then compute the validation loss on the references.
                generated_ids = self.net.generate(pixel_values)
                generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
                test_dataset_generated_text.extend(generated_text)

                outputs = self.net(input_ids=input_ids, pixel_values=pixel_values,
                                   attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss

                vloss = vloss + loss.asnumpy()
                step = step + 1
        # Generated captions in a COCO-style "annotations" format so they can be scored offline.
        testdatasetRES = {
            'annotations': [{'image_id': i, 'caption': text} for i, text in enumerate(test_dataset_generated_text)]
        }

        return vloss / step, testdatasetRES

# Load the model and set up the trainable parameters
ms.set_context(device_target='Ascend', device_id=0, pynative_synchronize=True)
processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
model = freeze_blip2_backbone(model, freeze_vit=True)
all_params = sum(p.size for p in model.parameters())
trainable_params = sum(p.size for p in model.trainable_params())
print(f'trainable params ratio = {trainable_params / all_params}')
# Load the data
dataset = load_dataset('advancedcv/Food500Cap')
# Due to resource constraints, train and evaluate on a 1/8 subset
train_dataset = dataset['train']
train_dataset = train_dataset.select(range(0, len(train_dataset), 8))
test_dataset = dataset['test']
test_dataset = test_dataset.select(range(0, len(test_dataset), 8))
train_loader = get_loader(train_dataset, processor, batch_size=8, shuffle=True, drop_remainder=True)
test_loader = get_loader(test_dataset, processor, batch_size=32, shuffle=False, drop_remainder=False)
# Ground-truth captions in the same COCO-style format as the generated results
testdatasetGTS = {
    'annotations': [{'image_id': i, 'caption': item['caption']} for i, item in enumerate(test_dataset)]
}
# Training
optimizer = AdamW(model.trainable_params(), lr=5e-5)
trainer = Trainer(net=model, processor=processor, optimizer=optimizer,
                  train_dataset=train_loader, eval_dataset=test_loader, save_path='./trainer_output')
trainer.train(10)
# Save the ground-truth and generated captions for offline evaluation
if trainer.run_eval:
    save_generated_text = {
        "testdatasetGTS": testdatasetGTS,
        "testdatasetRES_list": trainer.testdatasetRES_list
    }
    with open("./testdataset_generated_text.json", 'w', encoding='utf-8') as f:
        json.dump(save_generated_text, f, ensure_ascii=False)
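# Hedged usage sketch (an addition, not part of the original training run): reload the best
# checkpoint saved by Trainer and caption a single test image. Assumes the training above
# wrote './trainer_output/best_model' and reuses the in-memory processor.
best_model = Blip2ForConditionalGeneration.from_pretrained('./trainer_output/best_model')
best_model.set_train(False)
sample = test_dataset[0]
enc = processor(images=sample['image'])
pixel_values = ms.Tensor(np.asarray(enc['pixel_values']), ms.float32)
caption_ids = best_model.generate(pixel_values)
print(processor.batch_decode(caption_ids, skip_special_tokens=True)[0])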
# Evaluation
# The environment needed for caption-metric evaluation does not appear to be supported on
# Ascend devices, so the results saved above are evaluated separately on another machine
# with the companion script image_caption_eval.py.
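# Hedged sketch (an assumption, not the actual image_caption_eval.py): one way the JSON saved
# above could be scored offline with pycocoevalcap, if that package (and its Java dependency)
# is available. The helper name to_coco_format and the RUN_OFFLINE_EVAL switch are illustrative
# additions; the switch is off by default so this training script still runs unchanged on Ascend.
RUN_OFFLINE_EVAL = False
if RUN_OFFLINE_EVAL:
    from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
    from pycocoevalcap.cider.cider import Cider

    def to_coco_format(annotations):
        # Regroup the flat annotation list into {image_id: [{"caption": ...}, ...]} as PTBTokenizer expects.
        grouped = {}
        for ann in annotations:
            grouped.setdefault(ann["image_id"], []).append({"caption": ann["caption"]})
        return grouped

    with open("./testdataset_generated_text.json", encoding="utf-8") as f:
        saved = json.load(f)
    gts = PTBTokenizer().tokenize(to_coco_format(saved["testdatasetGTS"]["annotations"]))
    res = PTBTokenizer().tokenize(to_coco_format(saved["testdatasetRES_list"][-1]["annotations"]))
    cider_score, _ = Cider().compute_score(gts, res)
    print(f"CIDEr (last epoch) = {cider_score:.4f}")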