Skip to content

Commit fe1a324

Browse files
开源实习 BEiT 模型微调 (#1975)
1 parent aaaaaaa commit fe1a324

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

llm/finetune/BEiT/BEiT finetune.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Finetune microsoft beit-base-patch16-224 model
2+
- base model: [microsoft beit-base-patch16-224](https://huggingface.co/microsoft/beit-base-patch16-224)
3+
- dataset: [cifar10](https://huggingface.co/datasets/uoft-cs/cifar10)
4+
- PyTorch reference fine-tuning implementation: [GitHub](https://github.com/4everImmortality/microsoft-beit-cifar10-finetune)
5+
# Requirements
6+
## pytorch
7+
- GPU: RTX 4070ti 12G
8+
- cuda: 11.8
9+
- Python version: 3.10
10+
- torch version: 2.5.0
11+
- transformers version : 4.47.0
12+
## mindspore
13+
- Ascend: 910B
14+
- python: 3.9
15+
- mindspore: 2.3.1
16+
- mindnlp: 0.4.0
17+
# Fine-tuning results
18+
training for 3 epochs
19+
## torch
20+
| Epoch | eval_loss | eval_accuracy |
21+
|-------|-----------|--------------|
22+
| 1 | 0.193 | 94.4% |
23+
| 2 | 0.157 | 95.4% |
24+
| 3 | 0.117 | 96.2% |
25+
## mindspore
26+
| Epoch | eval_loss | eval_accuracy |
27+
|-------|-----------|--------------|
28+
| 1 | 0.416 | 96.4% |
29+
| 2 | 0.193 | 96.8% |
30+
| 3 | 0.158 | 97.2% |

llm/finetune/BEiT/BEiT_mind.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from mindspore.dataset import GeneratorDataset as ds_GeneratorDataset
2+
import numpy as np
3+
from sklearn.metrics import accuracy_score
4+
from mindnlp.engine import TrainingArguments, Trainer
5+
from mindnlp.transformers import AutoImageProcessor, BeitForImageClassification
6+
import mindspore
7+
from mindspore import Tensor
8+
from mindspore.dataset.transforms.py_transforms import Compose
9+
from mindspore.dataset.vision.py_transforms import (
10+
RandomResizedCrop, RandomHorizontalFlip, Resize, CenterCrop, ToTensor, Normalize
11+
)
12+
from datasets import load_dataset
13+
14+
# Load a CIFAR-10 subset: the first 5000 training and first 2000 test images.
train_ds, test_ds = load_dataset(
    'uoft-cs/cifar10',
    split=['train[:5000]', 'test[:2000]'],
)

# Hold out 10% of the training subset as a validation split.
splits = train_ds.train_test_split(test_size=0.1)
train_ds_hf, val_ds_hf = splits['train'], splits['test']
test_ds_hf = test_ds
21+
22+
# Build the class-index <-> class-name lookup tables from the dataset's
# label feature.
class_names = train_ds_hf.features['label'].names
id2label = dict(enumerate(class_names))
label2id = {name: index for index, name in id2label.items()}
26+
27+
# Load the image processor published with the pretrained BEiT checkpoint.
processor = AutoImageProcessor.from_pretrained('microsoft/beit-base-patch16-224')

# Normalization statistics and the input edge length are read from the
# processor so the preprocessing below matches the checkpoint's settings.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]
32+
33+
# Preprocessing pipelines. Both end with tensor conversion followed by
# normalization using the processor's mean/std.
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random resized crop and random horizontal flip.
train_steps = [
    RandomResizedCrop(size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
]
transform_train = Compose(train_steps)

# Validation/test pipeline: resize then center crop, no random operations.
eval_steps = [
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize,
]
transform_val = Compose(eval_steps)
47+
48+
# Batch-level transform functions for the Hugging Face datasets


def train_transforms(examples):
    """Add training-time ``pixel_values`` to a batch of examples.

    Args:
        examples: Mapping with an ``'img'`` entry holding an iterable of
            images that support ``.convert("RGB")`` (presumably PIL images
            -- confirm against the dataset schema).

    Returns:
        The same ``examples`` mapping, with ``'pixel_values'`` set to the
        list of ``transform_train`` outputs, one per input image.
    """
    augmented = []
    for picture in examples['img']:
        rgb_picture = picture.convert("RGB")
        augmented.append(transform_train(rgb_picture))
    examples['pixel_values'] = augmented
    return examples
55+
56+
57+
def val_transforms(examples):
    """Add evaluation-time ``pixel_values`` to a batch of examples.

    Args:
        examples: Mapping with an ``'img'`` entry holding an iterable of
            images that support ``.convert("RGB")`` (presumably PIL images
            -- confirm against the dataset schema).

    Returns:
        The same ``examples`` mapping, with ``'pixel_values'`` set to the
        list of ``transform_val`` outputs, one per input image.
    """
    processed = []
    for picture in examples['img']:
        rgb_picture = picture.convert("RGB")
        processed.append(transform_val(rgb_picture))
    examples['pixel_values'] = processed
    return examples
61+
62+
63+
# Attach the preprocessing functions to each Hugging Face split: the
# augmenting transform for training, the plain one for validation and test.
train_ds_hf.set_transform(train_transforms)
for eval_split in (val_ds_hf, test_ds_hf):
    eval_split.set_transform(val_transforms)
67+
68+
# Bridge from Hugging Face datasets to MindSpore datasets


def create_mindspore_dataset(hf_dataset):
    """Wrap a Hugging Face dataset in a MindSpore ``GeneratorDataset``.

    Args:
        hf_dataset: Iterable of examples, each exposing ``'pixel_values'``
            and ``'label'`` entries.

    Returns:
        A ``GeneratorDataset`` with the columns ``pixel_values`` (float32)
        and ``labels`` (int32).
    """

    def _samples():
        for record in hf_dataset:
            image = np.array(record['pixel_values'], dtype=np.float32)

            # The author's debug note records a raw shape of
            # (C, H, W) = (3, 224, 224). The guard below also accepts a
            # (1, C, H, W) array and drops its leading singleton axis.
            has_singleton_lead = image.ndim == 4 and image.shape[0] == 1
            if has_singleton_lead:
                image = image.squeeze(0)

            yield image, np.int32(record['label'])

    return ds_GeneratorDataset(
        _samples,
        column_names=['pixel_values', 'labels'],
        column_types=[mindspore.float32, mindspore.int32],
    )
92+
93+
94+
# Build the batched MindSpore datasets. Incomplete trailing batches are
# dropped for every split.
train_ds = create_mindspore_dataset(train_ds_hf)
train_ds = train_ds.batch(10, drop_remainder=True)  # 10 samples per batch

val_ds = create_mindspore_dataset(val_ds_hf)
val_ds = val_ds.batch(4, drop_remainder=True)

test_ds = create_mindspore_dataset(test_ds_hf)
test_ds = test_ds.batch(4, drop_remainder=True)
99+
100+
# 中间打印调试
101+
# for batch in train_ds.create_tuple_iterator():
102+
# pixel_batch, label_batch = batch
103+
# print("Batch shape:", pixel_batch.shape) # 格式 (10, 3, 224, 224)
104+
# break
105+
106+
# Training configuration
args = TrainingArguments(
    # Output locations
    output_dir="checkpoints",
    logging_dir='logs',
    # Optimization settings
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    # Gradient clipping is disabled; per the original author's note,
    # leaving it enabled fails with "Infer type failed."
    max_grad_norm=0.0,
    # Evaluate and checkpoint once per epoch, then reload the checkpoint
    # with the best "accuracy" metric when training finishes.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Keep every dataset column instead of letting the trainer prune them.
    remove_unused_columns=False,
)
123+
124+
# Instantiate BEiT from the pretrained checkpoint with a 10-class head.
checkpoint_name = 'microsoft/beit-base-patch16-224'
head_config = {
    "num_labels": 10,
    "id2label": id2label,
    "label2id": label2id,
}
# NOTE(review): ignore_mismatched_sizes is presumably needed because the
# checkpoint's classifier head has a different number of classes -- confirm.
model = BeitForImageClassification.from_pretrained(
    checkpoint_name,
    ignore_mismatched_sizes=True,
    **head_config,
)
133+
134+
# Evaluation metric


def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` holds
            per-class scores and is reduced with ``argmax`` along axis 1;
            ``labels`` holds the reference class indices.

    Returns:
        A dict with the single key ``"accuracy"``.
    """
    scores, labels = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    # scikit-learn's signature is accuracy_score(y_true, y_pred). Passing
    # the references first keeps the call correct if it is later swapped
    # for a metric where the argument order matters.
    return {"accuracy": accuracy_score(labels, predicted_classes)}
141+
142+
143+
# Assemble the trainer from the model, configuration, datasets, and metric
# function defined above.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

# Run fine-tuning.
trainer.train()

0 commit comments

Comments
 (0)