使用预训练的ResNet50,损失函数使用ArcFace,数据集使用Youtube Faces
- 下载Youtube Faces中的aligned_images_DB.tar.gz并解压,接着在config/settings.py中设置数据集路径
- 下载dlib 68关键点检测模型并解压,放在files/shape_predictor_68_face_landmarks.dat,你可以在config/settings.py中设置该参数。你也可以设置其他参数。
- 安装依赖
注意:如果你有GPU并且想使用CUDA版本的PyTorch,请参考PyTorch官网安装
pip install -r requirements.txt
- 训练模型
python train.py
- 数据集组织 testdataraw/人名/图片
- 识别测试
在训练完成后,你可以使用提供的
test.py脚本进行识别测试:python test.py --img /path/to/your/image.*
MIT