
Commit 2a09101

[Open Source Internship] blenderbot_small model fine-tuning (#1980)
1 parent 1455ae6 commit 2a09101

8 files changed: +1777, -0 lines changed
Lines changed: 25 additions & 0 deletions
# bigbird_pegasus fine-tuning comparison

## train loss

Comparison of training loss during fine-tuning

| epoch | mindnlp+mindspore | transformers+torch (4060) | transformers+torch (4060, another run) |
| ----- | ----------------- | ------------------------- | -------------------------------------- |
| 1 | 2.0958 | 8.7301 | 5.4650 |
| 2 | 1.969 | 8.1557 | 4.6890 |
| 3 | 1.8755 | 7.7516 | 4.2572 |
| 4 | 1.8264 | 7.5017 | 4.0263 |
| 5 | 1.7349 | 7.2614 | 3.9444 |
| 6 | 1.678 | 7.0559 | 3.8428 |
| 7 | 1.6937 | 6.8405 | 3.7187 |
| 8 | 1.654 | 6.7297 | 3.7192 |
| 9 | 1.6365 | 6.7136 | 3.5434 |
| 10 | 1.7003 | 6.6279 | 3.5881 |

## eval loss

Comparison of evaluation scores

| epoch | mindnlp+mindspore | transformers+torch (4060) | transformers+torch (4060, another run) |
| ----- | ------------------ | ------------------------- | -------------------------------------- |
| 1 | 2.1257965564727783 | 6.3235931396484375 | 4.264792442321777 |

llm/finetune/bigbird_pagesus/mindNLPDatatricksAuto.ipynb

Lines changed: 1095 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 55 additions & 0 deletions
# Fine-tuning Blenderbot_Small on Synthetic-Persona-Chat

## Hardware

Resource specification: NPU: 1*Ascend-D910B (memory: 64 GB), CPU: 24 cores, RAM: 192 GB

Compute center: Wuhan AI Computing Center

Image: mindspore_2_5_py311_cann8

Hardware for the torch runs: Nvidia 3090

## Model and dataset

Model: "facebook/blenderbot_small-90M"

Dataset: "google/Synthetic-Persona-Chat"
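For reference, a minimal sketch of how this model and dataset can be loaded. The `mindnlp.transformers` classes are the same ones used by the coqa training script in this commit; inspecting the splits with `load_dataset` is an assumption about the usual Hugging Face workflow and is not shown in the commit:

```python
from mindnlp.transformers import BlenderbotSmallForConditionalGeneration, BlenderbotSmallTokenizer
from datasets import load_dataset

# Load the pretrained checkpoint and its tokenizer.
model_name = "facebook/blenderbot_small-90M"
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)

# Load the persona-chat corpus and inspect its splits before preprocessing.
# (Split and column names depend on the dataset card; not shown in this commit.)
dataset = load_dataset("google/Synthetic-Persona-Chat")
print(dataset)
```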
## Training loss

| epoch | mindspore+mindnlp | Pytorch+transformers |
| ----- | ----------------- | -------------------- |
| 1 | 0.1737 | 0.2615 |
| 2 | 0.1336 | 0.1269 |
| 3 | 0.1099 | 0.0987 |
## Evaluation loss

| epoch | mindspore+mindnlp | Pytorch+transformers |
| ----- | ------------------- | -------------------- |
| 1 | 0.16312436759471893 | 0.160710409283638 |
| 2 | 0.15773458778858185 | 0.15692724287509918 |
| 3 | 0.15398454666137695 | 0.1593361645936966 |
| 4 | 0.15398454666137695 | 0.1593361645936966 |
## Dialogue test

* Input question:

  Nice to meet you too. What are you interested in?

* mindnlp answer before fine-tuning:

  i ' m not really sure . i ' ve always wanted to go back to school , but i don ' t know what i want to do yet .

* mindnlp answer after fine-tuning:

  user 2: i'm interested in a lot of things, but my main interests are music, art, and music. i also like to play video games, go to the movies, and spend time with my friends and family. my favorite video games are the legend of zelda series, and my favorite game is the witcher 3. name) what breath my his their i they ] include yes when philip boarity

* torch answer before fine-tuning:

  i ' m not really sure . i ' ve always wanted to go back to school , but i don ' t know what i want to do yet .

* torch answer after fine-tuning:

  user 2: i ' m interested in a lot of things , but my favorite ones are probably history and language . what do you like to do for fun ? hades is one of my favorite characters . hades is also my favorite character . hades namegardenblem pola litz strönape ception ddie ppon plata yder foundry patel fton darted sler bbins vili atsu ović endra scoe barons
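A minimal sketch of how these replies can be reproduced, mirroring the `generate` call in the coqa script included in this commit; the fine-tuned checkpoint directory is a placeholder:

```python
from mindnlp.transformers import BlenderbotSmallForConditionalGeneration, BlenderbotSmallTokenizer

# Placeholder path: point this at the directory written by model.save_pretrained(...).
checkpoint = "./blenderbot_personachat_finetuned"
tokenizer = BlenderbotSmallTokenizer.from_pretrained(checkpoint)
model = BlenderbotSmallForConditionalGeneration.from_pretrained(checkpoint)

question = "Nice to meet you too. What are you interested in?"
inputs = tokenizer([question], return_tensors="ms")  # MindSpore tensors
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```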
Lines changed: 80 additions & 0 deletions
# Fine-tuning Blenderbot_Small on coqa

## Hardware

Resource specification: NPU: 1*Ascend-D910B (memory: 64 GB), CPU: 24 cores, RAM: 192 GB

Compute center: Wuhan AI Computing Center

Image: mindspore_2_5_py311_cann8

Hardware for the torch runs: Nvidia 3090

## Model and dataset

Model: "facebook/blenderbot_small-90M"

Dataset: "stanfordnlp/coqa"
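The preprocessing used for coqa (see the training script later in this commit) keeps only the first question and answer of each record: the story and the first question are concatenated as the model input, and the first answer becomes the label. A minimal sketch of that mapping:

```python
def build_pair(example):
    # One coqa record holds a story plus several questions/answers;
    # keep only the first question/answer pair.
    first_question = example["questions"][0]
    first_answer = example["answers"]["input_text"][0]
    inputs = (example["story"] + " " + first_question).replace('"', "")
    labels = first_answer.replace('"', "")
    return {"input_ids": inputs, "labels": labels}
```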
## Training loss

| epoch | mindspore+mindnlp | Pytorch+transformers |
| ----- | ----------------- | -------------------- |
| 1 | 0.0117 | 0.3391 |
| 2 | 0.0065 | 0.0069 |
| 3 | 0.0041 | 0.0035 |
| 4 | 0.0027 | |
| 5 | 0.0017 | |
| 6 | 0.0012 | |
| 7 | 0.0007 | |
| 8 | 0.0005 | |
| 9 | 0.0003 | |
| 10 | 0.0002 | |
## Evaluation loss

| epoch | mindspore+mindnlp | Pytorch+transformers |
| ----- | -------------------- | -------------------- |
| 1 | 0.010459424927830696 | 0.010080045089125633 |
| 2 | 0.010958473198115826 | 0.008667134679853916 |
| 3 | 0.011061458848416805 | 0.00842051301151514 |
| 4 | 0.011254088021814823 | 0.00842051301151514 |
| 5 | 0.011891312897205353 | |
| 6 | 0.012321822345256805 | |
| 7 | 0.012598296627402306 | |
| 8 | 0.01246054656803608 | |
| 9 | 0.0124361552298069 | |
| 10 | 0.01264810748398304 | |
## Dialogue test

The question is the first example from the evaluation dataset; after fine-tuning the output does not look good.

* Input question:

  The Vatican Apostolic Library, more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula.

  The Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail.

  In March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online.

  The Vatican Secret Archives were separated from the library at the beginning of the 17th century; they contain another 150,000 items.

  Scholars have traditionally divided the history of the library into five periods, Pre-Lateran, Lateran, Avignon, Pre-Vatican and Vatican.

  The Pre-Lateran period, comprising the initial days of the library, dated from the earliest days of the Church. Only a handful of volumes survive from this period, though some are very significant. When was the Vat formally opened?

* mindnlp answer before fine-tuning:

  wow , that ' s a lot of information ! i ' ll have to check it out !

* mindnlp answer after fine-tuning:

  it was formally established in 1475 remarked wang commenced baxter vii affiliate xii ) detained amid xvi scarcely spokesman murmured pradesh condemned himweekriedly upheld kilometers ywood longitude reportedly unarmed sworth congressional quarreandrea according monsieur constituent zhang smiled ɪfellows combe mitt

* torch answer before fine-tuning:

  wow , that ' s a lot of information ! i ' ll have to check it out !

* torch answer after fine-tuning:

  1475 monsieur palermo pradesh ˈprincipality pali turbines constituent gallagher xii ɪxv odi pauline ɒgregory coefficient julien deutsche sbury roberto henrietta əenko militants gmina podium hya taliban hague ːkensington poole inmate livery habsburg longitude reid lieu@@
Lines changed: 145 additions & 0 deletions
from mindnlp.transformers import BlenderbotSmallForConditionalGeneration, BlenderbotSmallTokenizer
from mindnlp.engine import Trainer, TrainingArguments
from datasets import load_dataset, load_from_disk
import mindspore as ms
import os

# Set the execution mode and target device
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

# Set the HF_ENDPOINT environment variable (download via the Hugging Face mirror)
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Load the model and tokenizer
print("Loading model and tokenizer")
model_name = "facebook/blenderbot_small-90M"
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)
print("Model and tokenizer loaded")
# Test the output of the original (not yet fine-tuned) model
input = "The Vatican Apostolic Library, more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online. \n\nThe Vatican Secret Archives were separated from the library at the beginning of the 17th century; they contain another 150,000 items. \n\nScholars have traditionally divided the history of the library into five periods, Pre-Lateran, Lateran, Avignon, Pre-Vatican and Vatican. \n\nThe Pre-Lateran period, comprising the initial days of the library, dated from the earliest days of the Church. Only a handful of volumes survive from this period, though some are very significant.When was the Vat formally opened?"
print("input question:", input)
input_tokens = tokenizer([input], return_tensors="ms")
output_tokens = model.generate(**input_tokens)
print("output answer:", tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0])

# # Set a padding token (BlenderbotSmall has no pad_token by default)
# # tokenizer.pad_token = tokenizer.eos_token  # use eos_token as the padding token
# # model.config.pad_token_id = tokenizer.eos_token_id

print("Loading dataset")
# Path where the preprocessed dataset is saved
dataset_path = "./dataset_valid_preprocessed"
# Check whether a preprocessed dataset already exists
if os.path.exists(dataset_path):
    # Load the preprocessed datasets
    dataset_train = load_from_disk("./dataset_train_preprocessed")
    dataset_valid = load_from_disk("./dataset_valid_preprocessed")
else:
    dataset = load_dataset("stanfordnlp/coqa")
    print("dataset finished\n")
    print("dataset:", dataset)
    print("\ndataset[train][0]:", dataset["train"][0])
    print("\ndataset[validation][0]:", dataset["validation"][0])
    dataset_train = dataset["train"]
    dataset_valid = dataset["validation"]
    # Preprocessing: each coqa record is one story with several questions and answers.
    # Only the first question and the first answer are used: the story and the first
    # question are concatenated as the model input, and the first answer is the label.
    def preprocess_function(examples):
        # Text of the first question
        first_question = examples['questions'][0]
        # Text of the first answer
        first_answer = examples['answers']['input_text'][0]
        # Concatenate the story and the first question as the model input
        inputs = examples['story'] + " " + first_question
        # Strip stray quotation marks
        inputs = inputs.replace('"', '')
        # Use the first answer as the model output
        labels = first_answer
        # Strip stray quotation marks
        labels = labels.replace('"', '')
        return {'input_ids': inputs, 'labels': labels}

    def tokenize_function(examples):
        # Tokenize the inputs
        model_inputs = tokenizer(examples['input_ids'], max_length=512, truncation=True, padding="max_length")
        # Tokenize the labels
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(examples['labels'], max_length=512, truncation=True, padding="max_length")
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    # Apply the preprocessing functions
    dataset_train = dataset_train.map(preprocess_function, batched=False)
    dataset_train = dataset_train.map(tokenize_function, batched=True)
    dataset_train = dataset_train.remove_columns(["source", "story", "questions", "answers"])

    dataset_valid = dataset_valid.map(preprocess_function, batched=False)
    dataset_valid = dataset_valid.map(tokenize_function, batched=True)
    dataset_valid = dataset_valid.remove_columns(["source", "story", "questions", "answers"])

    dataset_train.save_to_disk("./dataset_train_preprocessed")
    dataset_valid.save_to_disk("./dataset_valid_preprocessed")
print("dataset_train_tokenized:", dataset_train)

print("Converting to a MindSpore dataset")
import numpy as np
def data_generator(dataset):
    for item in dataset:
        yield (
            np.array(item["input_ids"], dtype=np.int32),
            np.array(item["attention_mask"], dtype=np.int32),
            np.array(item["labels"], dtype=np.int32)
        )
import mindspore.dataset as ds
def create_mindspore_dataset(dataset, shuffle=True):
    return ds.GeneratorDataset(
        source=lambda: data_generator(dataset),  # wrap the generator in a lambda
        column_names=["input_ids", "attention_mask", "labels"],
        shuffle=shuffle,
        num_parallel_workers=1
    )
dataset_train_tokenized = create_mindspore_dataset(dataset_train, shuffle=True)
dataset_valid_tokenized = create_mindspore_dataset(dataset_valid, shuffle=False)

TOKENS = 20  # defined here but not referenced later in this script
EPOCHS = 10
BATCH_SIZE = 4
training_args = TrainingArguments(
    output_dir='./MindNLPblenderbot_coqa_finetuned',
    overwrite_output_dir=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    save_steps=500,  # Save a checkpoint every 500 steps
    save_total_limit=2,  # Keep only the last 2 checkpoints
    logging_dir="./mindsporelogs",  # Directory for logs
    logging_steps=100,  # Log every 100 steps
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    eval_steps=500,  # Evaluation frequency
    warmup_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,  # Weight decay
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train_tokenized,
    eval_dataset=dataset_valid_tokenized
)
# Start training
print("Starting training")
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")
model.save_pretrained("./blenderbot_coqa_finetuned")
tokenizer.save_pretrained("./blenderbot_coqa_finetuned")
fine_tuned_model = BlenderbotSmallForConditionalGeneration.from_pretrained("./blenderbot_coqa_finetuned")
fine_tuned_tokenizer = BlenderbotSmallTokenizer.from_pretrained("./blenderbot_coqa_finetuned")


print("Testing the dialogue again")
input = "The Vatican Apostolic Library, more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online. \n\nThe Vatican Secret Archives were separated from the library at the beginning of the 17th century; they contain another 150,000 items. \n\nScholars have traditionally divided the history of the library into five periods, Pre-Lateran, Lateran, Avignon, Pre-Vatican and Vatican. \n\nThe Pre-Lateran period, comprising the initial days of the library, dated from the earliest days of the Church. Only a handful of volumes survive from this period, though some are very significant.When was the Vat formally opened?"
print("input question:", input)
input_tokens = fine_tuned_tokenizer([input], return_tensors="ms")
output_tokens = fine_tuned_model.generate(**input_tokens)
print("output answer:", fine_tuned_tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0])
