一个最简单的PyTorch图像分类示例,使用CIFAR-10数据集训练一个基础的卷积神经网络。
- ✅ 完全自动化: 所有数据和依赖都会自动下载
- ✅ GPU支持: 自动检测并使用GPU加速训练
- ✅ 简单易懂: 代码结构清晰,适合初学者
- ✅ 完整流程: 包含数据加载、模型训练、测试和推理
pip install -r requirements.txt
python train.py
训练过程会:
- 自动下载CIFAR-10数据集
- 训练一个简单的CNN模型
- 每个epoch显示训练进度
- 保存最佳模型到
best_model.pth
python test.py
测试会:
- 加载训练好的模型
- 在测试集上评估准确率
- 显示详细的分类报告
python inference.py
推理会:
- 随机选择测试集中的图片
- 显示预测结果和真实标签
- 可视化预测过程
image/
├── README.md # 项目说明文档
├── requirements.txt # Python依赖包
├── train.py # 训练脚本
├── test.py # 测试脚本
├── inference.py # 推理脚本
├── model.py # 模型定义
├── dataset.py # 数据加载器
└── utils.py # 工具函数
使用简单的CNN架构:
- 2个卷积层 + ReLU + MaxPool
- 1个全连接层
- Dropout防止过拟合
- 10个输出类别(CIFAR-10)
CIFAR-10数据集:
- 60,000张32x32彩色图片
- 10个类别:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车
- 50,000张训练图片 + 10,000张测试图片
- 数据会自动下载到
./data
目录
- 学习率: 0.001
- 批次大小: 64
- 训练轮数: 10
- 优化器: Adam
- 损失函数: CrossEntropyLoss
- 训练时间:约5-10分钟(GPU)/ 20-30分钟(CPU)
- 预期准确率:60-70%
- 模型大小:约2MB
- Python 3.7+
- PyTorch 1.9+
- CUDA(可选,用于GPU加速)
# 检查CUDA是否可用
python -c "import torch; print(torch.cuda.is_available())"
# 检查GPU设备
python -c "import torch; print(torch.cuda.get_device_name(0))"
如果遇到内存不足,可以:
- 减小批次大小(修改train.py中的batch_size)
- 使用CPU训练(会自动检测)
如果下载数据集失败:
- 检查网络连接
- 可能需要科学上网
- 数据会缓存,重新运行即可
- 数据增强: 添加随机翻转、旋转等
- 模型改进: 使用ResNet、VGG等预训练模型
- 超参数调优: 尝试不同的学习率、批次大小
- 可视化: 添加训练曲线、混淆矩阵等
MIT License - 可自由使用和修改
欢迎提交Issue和Pull Request!
Happy Coding! 🎉