Commit ef9b565

[Open-source internship] Fine-tune the blip_2 model (#1965)

1 parent d48deaf commit ef9b565
File tree

5 files changed: +325 -0 lines changed

llm/finetune/blip2/README.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# FineTune Blip2-opt-2.7b with Food500Cap

- [base model: blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)
- [dataset: Food500Cap](https://huggingface.co/datasets/advancedcv/Food500Cap)

Due to limited resources, only 1/8 of the original training and test data is used, i.e. `dataset = dataset.select(range(0, len(dataset), 8))`; see the sketch below.
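A minimal sketch of this subsetting (it mirrors the corresponding lines in blip2_finetune_with_Food500Cap.py):

```python
from datasets import load_dataset

dataset = load_dataset('advancedcv/Food500Cap')
# Keep every 8th sample of each split, i.e. 1/8 of the original data.
train_dataset = dataset['train'].select(range(0, len(dataset['train']), 8))
test_dataset = dataset['test'].select(range(0, len(dataset['test']), 8))
```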
## code

- blip2_finetune_with_Food500Cap.py: training script for MindSpore
- image_caption_eval.py: evaluation script

## requirements

### mindspore

- Ascend 910B
- Python 3.9
- MindSpore 2.3.1
- MindNLP 0.4.1 (requires [mindnlp PR1958](https://github.com/mindspore-lab/mindnlp/pull/1958) to be merged)
- tokenizers>=0.21.0, datasets

### pytorch

- GPU V100
- CUDA 11.8
- Python 3.10
- PyTorch 2.1.0
- Transformers 4.45.2
- tokenizers>=0.21.0, datasets, accelerate

## train loss

![](./images/train_loss.png)

## image caption eval

![](./images/image_caption_eval.png)
llm/finetune/blip2/blip2_finetune_with_Food500Cap.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
from mindnlp.transformers import Blip2ForConditionalGeneration, Blip2Processor
from mindnlp.core.optim import AdamW
from mindnlp.core import value_and_grad

import mindspore as ms
from mindspore.dataset import GeneratorDataset

from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import json


def freeze_blip2_backbone(model, freeze_vit=True):
    """
    Freeze the backbone of the blip2-opt model.
    If freeze_vit is True, freeze the vision model, including embeddings and encoder.
    The language model is always frozen.
    blip2-opt model architecture:
    {
        "query_tokens": {},
        "vision_model": {
            "embeddings": {},
            "encoder": {},
            "post_layernorm": {},
        },
        "qformer": {},
        "language_projection": {},
        "language_model": {}
    }
    """
    if freeze_vit:
        for param in model.vision_model.embeddings.parameters():
            param.requires_grad = False
        for param in model.vision_model.encoder.parameters():
            param.requires_grad = False
    else:
        for param in model.vision_model.parameters():
            param.requires_grad = True

    for param in model.language_model.parameters():
        param.requires_grad = False

    return model


class ImageCaptioningDataset:
    """Wraps a Hugging Face dataset so each item is preprocessed by the BLIP-2 processor."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if not isinstance(idx, int):
            idx = int(idx)
        item = self.dataset[idx]
        encoding = self.processor(images=item['image'], text=item['caption'], max_length=96, padding="max_length")
        return np.asarray(encoding["pixel_values"]).squeeze(0), np.asarray(encoding["input_ids"]), np.asarray(encoding["attention_mask"])


def get_loader(dataset, processor, batch_size, shuffle=True, num_workers=1, drop_remainder=True):
    dataset = ImageCaptioningDataset(dataset, processor)
    return GeneratorDataset(source=dataset,
                            column_names=["pixel_values", "input_ids", "attention_mask"],
                            shuffle=shuffle,
                            num_parallel_workers=num_workers
                            ).batch(batch_size=batch_size,
                                    drop_remainder=drop_remainder)


class Trainer:
    def __init__(self, net, processor, optimizer,
                 train_dataset, eval_dataset=None, save_path=None
                 ):
        self.net = net
        self.processor = processor
        self.opt = optimizer
        self.train_dataset = train_dataset
        self.weights = self.net.trainable_params()
        self.value_and_grad = value_and_grad(fn=self.forward_fn, params_or_argnums=self.weights)
        self.run_eval = eval_dataset is not None
        self.save_path = save_path
        if self.run_eval:
            self.eval_dataset = eval_dataset
            self.testdatasetRES_list = []

    def forward_fn(self, input_ids, pixel_values, attention_mask):
        # The caption tokens serve as both inputs and labels for the language-modeling loss.
        outputs = self.net(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        return loss

    def train_single(self, input_ids, pixel_values, attention_mask):
        self.opt.zero_grad()
        loss = self.value_and_grad(input_ids, pixel_values, attention_mask)
        self.opt.step()
        return loss

    def train(self, epochs):
        best_val_loss = float('inf')

        for epoch in range(0, epochs):
            print("\nEpoch {}/{}".format(epoch + 1, epochs))
            self.net.set_train(True)
            tloss = 0
            step = 0
            for batch in tqdm(self.train_dataset.create_dict_iterator(), desc='training...'):
                input_ids = batch["input_ids"]
                pixel_values = batch["pixel_values"]
                attention_mask = batch["attention_mask"]

                loss = self.train_single(input_ids, pixel_values, attention_mask)

                tloss = tloss + loss.asnumpy()
                step = step + 1

            tloss /= step
            print("\tTrain Loss {:.04f}".format(tloss))

            if self.run_eval:
                self.net.set_train(False)
                val_loss, testdatasetRES = self.eval()
                self.testdatasetRES_list.append(testdatasetRES)
                print("Epoch {} complete! Validation Loss : {}".format(epoch + 1, val_loss))
                if val_loss < best_val_loss:
                    print("Best validation Loss improved from {} to {}".format(best_val_loss, val_loss))
                    best_val_loss = val_loss
                    if self.save_path is not None:
                        print("saving model...")
                        self.net.save_pretrained(self.save_path + '/best_model')

    def eval(self):
        vloss = 0
        step = 0
        test_dataset_generated_text = []
        with ms._no_grad():
            for batch in tqdm(self.eval_dataset.create_dict_iterator(), desc='generating image captions on test dataset'):
                input_ids = batch["input_ids"]
                pixel_values = batch["pixel_values"]
                attention_mask = batch["attention_mask"]

                generated_ids = self.net.generate(pixel_values)
                generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
                test_dataset_generated_text.extend(generated_text)

                outputs = self.net(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss

                vloss = vloss + loss.asnumpy()
                step = step + 1
        testdatasetRES = {
            'annotations': [{'image_id': i, 'caption': text} for i, text in enumerate(test_dataset_generated_text)]
        }

        return vloss / step, testdatasetRES


# Load the model and set the trainable parameters
ms.set_context(device_target='Ascend', device_id=0, pynative_synchronize=True)
processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
model = freeze_blip2_backbone(model, freeze_vit=True)
all_params = sum(p.size for p in model.parameters())
trainable_params = sum(p.size for p in model.trainable_params())
print(f'trainable params ratio = {trainable_params / all_params}')
# Load the data
dataset = load_dataset('advancedcv/Food500Cap')
# Due to resource limits, train on a subset (every 8th sample of each split)
train_dataset = dataset['train']
train_dataset = train_dataset.select(range(0, len(train_dataset), 8))
test_dataset = dataset['test']
test_dataset = test_dataset.select(range(0, len(test_dataset), 8))
train_loader = get_loader(train_dataset, processor, batch_size=8, shuffle=True, drop_remainder=True)
test_loader = get_loader(test_dataset, processor, batch_size=32, shuffle=False, drop_remainder=False)
testdatasetGTS = {
    'annotations': [{'image_id': i, 'caption': item['caption']} for i, item in enumerate(test_dataset)]
}
# Training
optimizer = AdamW(model.trainable_params(), lr=5e-5)
trainer = Trainer(net=model, processor=processor, optimizer=optimizer, train_dataset=train_loader, eval_dataset=test_loader, save_path='./trainer_output')
trainer.train(10)
if trainer.run_eval:
    save_generated_text = {
        "testdatasetGTS": testdatasetGTS,
        "testdatasetRES_list": trainer.testdatasetRES_list
    }
    with open("./testdataset_generated_text.json", 'w', encoding='utf-8') as f:
        json.dump(save_generated_text, f, ensure_ascii=False)
# Evaluation
# The packages needed for evaluation do not appear to be supported on Ascend devices,
# so the generated captions are saved here and scored separately on another machine
# with image_caption_eval.py.
llm/finetune/blip2/image_caption_eval.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# image caption evaluation
# required packages: pycocoevalcap : pip install pycocoevalcap
# required packages: java : conda install -c conda-forge openjdk=11 -y

from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice


class COCOEvalCap:
    def __init__(self, images, gts, res):
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.params = {'image_id': images}
        self.gts = gts
        self.res = res

    def evaluate(self):
        imgIds = self.params['image_id']
        gts = self.gts
        res = self.res

        # =================================================
        # Tokenize ground-truth and generated captions
        # =================================================
        tokenizer = PTBTokenizer()
        gts = tokenizer.tokenize(gts)
        res = tokenizer.tokenize(res)

        # =================================================
        # Set up scorers
        # =================================================
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(), "METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
            (Spice(), "SPICE"),
        ]

        # =================================================
        # Compute scores
        # =================================================
        for scorer, method in scorers:
            score, scores = scorer.compute_score(gts, res)
            if isinstance(method, list):
                for sc, scs, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    self.setImgToEvalImgs(scs, imgIds, m)
            else:
                self.setEval(score, method)
                self.setImgToEvalImgs(scores, imgIds, method)
        self.setEvalImgs()

    def setEval(self, score, method):
        self.eval[method] = score

    def setImgToEvalImgs(self, scores, imgIds, method):
        for imgId, score in zip(imgIds, scores):
            if imgId not in self.imgToEval:
                self.imgToEval[imgId] = {}
                self.imgToEval[imgId]["image_id"] = imgId
            self.imgToEval[imgId][method] = score

    def setEvalImgs(self):
        self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]


def calculate_metrics(rng, datasetGTS, datasetRES):
    imgIds = rng
    gts = {}
    res = {}

    imgToAnnsGTS = {ann['image_id']: [] for ann in datasetGTS['annotations']}
    for ann in datasetGTS['annotations']:
        imgToAnnsGTS[ann['image_id']] += [ann]

    imgToAnnsRES = {ann['image_id']: [] for ann in datasetRES['annotations']}
    for ann in datasetRES['annotations']:
        imgToAnnsRES[ann['image_id']] += [ann]

    for imgId in imgIds:
        gts[imgId] = imgToAnnsGTS[imgId]
        res[imgId] = imgToAnnsRES[imgId]

    evalObj = COCOEvalCap(imgIds, gts, res)
    evalObj.evaluate()
    return evalObj.eval


if __name__ == '__main__':
    rng = range(2)
    datasetGTS = {
        'annotations': [{'image_id': 0, 'caption': 'the man is playing a guitar'},
                        {'image_id': 0, 'caption': 'a man is playing a guitar'},
                        {'image_id': 1, 'caption': 'a woman is slicing cucumbers'},
                        {'image_id': 1, 'caption': 'the woman is slicing cucumbers'},
                        {'image_id': 1, 'caption': 'a woman is cutting cucumbers'}]
    }
    datasetRES = {
        'annotations': [{'image_id': 0, 'caption': 'man is playing guitar'},
                        {'image_id': 1, 'caption': 'a woman is cutting vegetables'}]
    }
    print(calculate_metrics(rng, datasetGTS, datasetRES))
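For the Food500Cap run itself, the JSON written by blip2_finetune_with_Food500Cap.py can be scored with calculate_metrics on a machine where pycocoevalcap and Java are available. A minimal sketch, assuming only the key names used when the JSON was saved by the training script above (the file and module names are those from this commit):

```python
import json

from image_caption_eval import calculate_metrics

# Load the ground-truth captions and the per-epoch generated captions
# saved by the training script.
with open("./testdataset_generated_text.json", "r", encoding="utf-8") as f:
    saved = json.load(f)

datasetGTS = saved["testdatasetGTS"]
num_images = len(datasetGTS["annotations"])  # one ground-truth caption per test image

for epoch, datasetRES in enumerate(saved["testdatasetRES_list"], start=1):
    scores = calculate_metrics(range(num_images), datasetGTS, datasetRES)
    print(f"epoch {epoch}: {scores}")
```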
The two remaining files in the commit, ./images/train_loss.png and ./images/image_caption_eval.png, are binary images and are not rendered here.
