from mindspore.dataset import GeneratorDataset as ds_GeneratorDataset
import numpy as np
from sklearn.metrics import accuracy_score
from mindnlp.engine import TrainingArguments, Trainer
from mindnlp.transformers import AutoImageProcessor, BeitForImageClassification
import mindspore
from mindspore import Tensor
from mindspore.dataset.transforms.py_transforms import Compose
from mindspore.dataset.vision.py_transforms import (
    RandomResizedCrop, RandomHorizontalFlip, Resize, CenterCrop, ToTensor, Normalize
)
from datasets import load_dataset

# Load the dataset
train_ds, test_ds = load_dataset(
    'uoft-cs/cifar10', split=['train[:5000]', 'test[:2000]'])
splits = train_ds.train_test_split(test_size=0.1)
train_ds_hf = splits['train']
val_ds_hf = splits['test']
test_ds_hf = test_ds

# Build the label mappings
id2label = {id: label for id, label in enumerate(
    train_ds_hf.features['label'].names)}
label2id = {label: id for id, label in id2label.items()}

# Initialize the image processor
processor = AutoImageProcessor.from_pretrained(
    'microsoft/beit-base-patch16-224')
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]
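# Illustrative aside (not part of the original script): some image processors store
# the target resolution under "shortest_edge" instead of "height"; a defensive lookup
# covering both layouts could be:
# size = processor.size.get("height", processor.size.get("shortest_edge", 224))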

# Define the preprocessing pipelines
normalize = Normalize(mean=image_mean, std=image_std)
transform_train = Compose([
    RandomResizedCrop(size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])
transform_val = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize,
])
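# Note: py_transforms' ToTensor converts the (H, W, C) PIL image to a (C, H, W)
# float32 array scaled to [0, 1], which is why Normalize uses the processor's
# image_mean/image_std (also expressed in the [0, 1] range) rather than 0-255 values.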

# Define the Hugging Face dataset transforms


def train_transforms(examples):
    examples['pixel_values'] = [transform_train(
        image.convert("RGB")) for image in examples['img']]
    return examples


def val_transforms(examples):
    examples['pixel_values'] = [transform_val(
        image.convert("RGB")) for image in examples['img']]
    return examples


# Apply the transforms to the raw datasets
train_ds_hf.set_transform(train_transforms)
val_ds_hf.set_transform(val_transforms)
test_ds_hf.set_transform(val_transforms)

# Create MindSpore datasets


def create_mindspore_dataset(hf_dataset):
    def generator():
        for example in hf_dataset:
            # Get the image data
            pixel_data = np.array(example['pixel_values'], dtype=np.float32)

            # Intermediate debug print
            # print("Raw pixel_data shape:", pixel_data.shape)  # (C, H, W) (3, 224, 224)

            # Handle an extra leading batch dimension
            if pixel_data.ndim == 4 and pixel_data.shape[0] == 1:
                # (1, C, H, W)
                pixel_data = pixel_data.squeeze(0)

            yield pixel_data, np.int32(example['label'])

    return ds_GeneratorDataset(
        generator,
        column_names=['pixel_values', 'labels'],
        column_types=[mindspore.float32, mindspore.int32]
    )


# Create the batched datasets
train_ds = create_mindspore_dataset(train_ds_hf).batch(
    10, drop_remainder=True)  # 10 samples per batch
val_ds = create_mindspore_dataset(val_ds_hf).batch(4, drop_remainder=True)
test_ds = create_mindspore_dataset(test_ds_hf).batch(4, drop_remainder=True)

# Intermediate debug print
# for batch in train_ds.create_tuple_iterator():
#     pixel_batch, label_batch = batch
#     print("Batch shape:", pixel_batch.shape)  # shape (10, 3, 224, 224)
#     break

# Initialize the training arguments
args = TrainingArguments(
    output_dir="checkpoints",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
    remove_unused_columns=False,
    max_grad_norm=0.0,  # disable gradient clipping; otherwise "Infer type failed." is raised
)

# Initialize the model
model = BeitForImageClassification.from_pretrained(
    'microsoft/beit-base-patch16-224',
    num_labels=10,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
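# ignore_mismatched_sizes=True drops the checkpoint's 1000-way ImageNet classification
# head and re-initializes a 10-way head sized for the CIFAR-10 labels above.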

# Define the evaluation metric


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy_score(labels, predictions)}
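# Quick illustrative check (not part of the original commit): compute_metrics receives
# (logits, labels) from the Trainer, so a toy call looks like
# compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0])))  # {'accuracy': 1.0}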


# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

# Start training
trainer.train()
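
# Illustrative follow-up (not part of the original commit): the test split built above
# is never scored. Assuming mindnlp's Trainer mirrors the Hugging Face Trainer API,
# where evaluate() accepts an explicit dataset, the fine-tuned model can be checked with:
test_metrics = trainer.evaluate(test_ds)
print(test_metrics)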