Releases: mpj1234/CNN-ML-Classify-PyTorch
Releases · mpj1234/CNN-ML-Classify-PyTorch
V1.0
20 Feb 06:50
Compare
Sorry, something went wrong.
No results found
CNN-ML-Classify-PyTorch
使用PyTorch预训练好的RNN模型来提取预分类特征,进而使用预分类特征在机器学习上进行训练和分类。
项目来源
项目借鉴于github上的(地址 )
此示例使用来自 Caltech 图像的 48 个标记图像的子集设置(http://www.vision.caltech.edu/Image_Datasets/Caltech101/),每个标签限制在 40 到 80
张图像之间。图像被CNN模型进行特征提取。
使用 2048-d 特征进行降维 t-distributed stochastic neighbor embedding
(t-SNE) t-分布随机邻域嵌入来转换它们成易于可视化的二维特征。请注意,使用 t-SNE
作为一个提供信息的步骤。如果相同的颜色/标签点大多是聚集在一起,我们很有可能可以使用这些功能训练一个高精度的分类器。
将 2048-d 标记的特征呈现给许多分类器。 最初,该项目是训练支持向量机 分类图像,但是为了比较,这已扩展到下列的:
Support Vector Machine (SVM, LinearSVC)
Extra Trees (ET)
Random Forest (RF)
K-Nearest Neighbor (KNN)
Multi-Layer Perceptron (ML)
Gaussian Naive Bayes (GNB)
Linear Discriminant Analysis (LDA)
Quadratic Discriminant Analysis (QDA)
安装教程
需要安装 PyTorch GPU 和 sklearn,最好使用GPU。 电脑显卡型号:NVIDIA GeForce
GTX1650,整个train.py
文件运行时间233.41s,整个test.py
文件运行时间61.64s。
使用说明
下载完代码后,首先直接在此目录下解压caltech_101_images.zip
,然后将图像放入./caltech_101_images/train
文件夹中,同时和train同一级目录创建test目录。
最终文件结构:
|
|-- caltech_101_images
| |-- train 解压放置的目录
| |-- test
可以直接使用datasetSegmentation.py
脚本进行数据分割,脚本默认是30%测试数据,脚本会将分割好的测试数据集从train
目录搬运到test
目录中。
之后可以使用train.py
脚本进行训练,脚本会将训练好的模型放到model
目录中,同时在训练中也会进行test的测试,这个测试是和test.py
这个脚本测试结果相同,test.py
这个脚本测试是需要train.py
训练完的模型。
可以使用c-svc_classify_features.py
脚本进行SVC的网格搜索,寻找较好的参数。
训练和测试的结果如下:
训练集图片数量: 2071;测试集图片数量:860.
每个CNN模型运行train.py
的时间和最好的正确率
模型
time
best acc
特征维度
Inception V3
194.68s
94.4%
2048
Resnet18
73.47s
91.9%
512
Resnet34
84.78s
94.0%
512
Resnet50
143.16s
93.3%
2048
ResNext50_32x4d
155.48s
94.9%
2048
DenseNet121
121.56s
96.9%
1024
MNASNet0_5
104.16s
91.7%
1280
MNASNet1_0
94.50s
93.0%
1280
Inception V3 模型
模型
时间
正确率
LinearSVC
5.00s
94.4%
SVC
12.72s
93.5%
Extra Tree
0.70s
88.5%
Random Forest
7.73s
85.1%
K-Nearest Neighbours
2.68s
87.1%
Multi-layer Perceptron
10.43s
94.1%
Gaussian Naive Bayes
1.10s
87.7%
Linear Discriminant Analysis
4.38s
31.9%
Quadratic Discriminant Analysis
1.06s
3.7%
Resnet18 模型
模型
时间
正确率
LinearSVC
2.33s
91.4%
SVC
2.90s
91.5%
Extra Tree
0.48s
85.8%
Random Forest
4.26s
82.2%
K-Nearest Neighbours
0.71s
80.9%
Multi-layer Perceptron
3.75s
91.9%
Gaussian Naive Bayes
0.35s
87.8%
Linear Discriminant Analysis
0.40s
90.7%
Quadratic Discriminant Analysis
0.51s
3.7%
Resnet34 模型
模型
时间
正确率
LinearSVC
1.94s
93.0%
SVC
2.78s
93.5%
Extra Tree
0.43s
90.1%
Random Forest
4.27s
86.6%
K-Nearest Neighbours
0.68s
87.9%
Multi-layer Perceptron
3.01s
93.7%
Gaussian Naive Bayes
0.27s
90.3%
Linear Discriminant Analysis
0.39s
94.0%
Quadratic Discriminant Analysis
0.46s
3.8%
Resnet50 模型
模型
时间
正确率
LinearSVC
8.35s
92.2%
SVC
10.15s
93.0%
Extra Tree
0.78s
88.8%
Random Forest
8.28s
87.2%
K-Nearest Neighbours
2.70s
84.2%
Multi-layer Perceptron
16.04s
93.3%
Gaussian Naive Bayes
1.07s
89.4%
Linear Discriminant Analysis
4.46s
31.9%
Quadratic Discriminant Analysis
1.04s
2.2%
ResNext50_32x4d 模型
模型
时间
正确率
LinearSVC
6.37s
94.5%
SVC
10.95s
93.7%
Extra Tree
0.79s
88.8%
Random Forest
8.38s
87.0%
K-Nearest Neighbours
2.64s
87.4%
Multi-layer Perceptron
14.96s
94.9%
Gaussian Naive Bayes
1.07s
91.0%
Linear Discriminant Analysis
4.45s
29.7%
Quadratic Discriminant Analysis
1.04s
3.0%
DenseNet121 模型
模型
时间
正确率
LinearSVC
2.71s
96.3%
SVC
5.72s
96.5%
Extra Tree
0.55s
92.7%
Random Forest
5.50s
91.2%
K-Nearest Neighbours
1.33s
92.7%
Multi-layer Perceptron
3.78s
96.3%
Gaussian Naive Bayes
0.53s
93.3%
Linear Discriminant Analysis
1.11s
96.9%
Quadratic Discriminant Analysis
0.74s
1.7%
MNASNet0_5 模型
模型
时间
正确率
LinearSVC
1.41s
91.7%
SVC
12.20s
56.3%
Extra Tree
0.50s
81.3%
Random Forest
3.72s
77.0%
K-Nearest Neighbours
1.76s
76.5%
Multi-layer Perceptron
10.10s
91.5%
Gaussian Naive Bayes
0.72s
51.0%
Linear Discriminant Analysis
2.25s
85.6%
Quadratic Discriminant Analysis
0.82s
2.1%
MNASNet1_0 模型
模型
时间
正确率
LinearSVC
2.68s
93.0%
SVC
6.96s
91.9%
Extra Tree
0.75s
86.4%
Random Forest
6.13s
83.7%
K-Nearest Neighbours
1.66s
84.1%
Multi-layer Perceptron
5.26s
93.0%
Gaussian Naive Bayes
0.67s
81.5%
Linear Discriminant Analysis
2.06s
88.5%
Quadratic Discriminant Analysis
0.77s
92.62%