Skip to content

Commit e3e02b8

Browse files
committed
add commandparser and more options in deepgaze1 sample
1 parent c756d5a commit e3e02b8

File tree

5 files changed

+58
-22
lines changed

5 files changed

+58
-22
lines changed

modules/datasets/samples/track_vot.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,4 @@ int main(int argc, char *argv[])
9696

9797
getchar();
9898
return 0;
99-
}
99+
}

modules/saliency/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ endif()
44

55
set(the_description "Saliency API")
66

7-
ocv_define_module(saliency opencv_imgproc opencv_features2d opencv_dnn WRAP python)
7+
ocv_define_module(saliency opencv_imgproc opencv_datasets opencv_features2d opencv_dnn WRAP python)
88

99
ocv_warnings_disable(CMAKE_CXX_FLAGS -Woverloaded-virtual)

modules/saliency/include/opencv2/saliency/saliencySpecializedClasses.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ class CV_EXPORTS_W DeepGaze1 : public StaticSaliency
166166
std::vector<double> weights;
167167

168168
public:
169-
DeepGaze1();
169+
DeepGaze1( std::string = "deploy.prototxt", std::string = "bvlc_alexnet.caffemodel" );
170170
DeepGaze1( std::string, std::string, std::vector<std::string>, unsigned );
171+
DeepGaze1( std::string, std::string, std::vector<std::string>, std::vector<double> );
171172
virtual ~DeepGaze1();
172173
CV_WRAP static Ptr<DeepGaze1> create()
173174
{

modules/saliency/samples/DeepGaze1Sample.cpp

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
// of this distribution and at http://opencv.org/license.html.
44

55
#include <opencv2/dnn.hpp>
6+
#include <opencv2/core.hpp>
67
#include <opencv2/imgproc.hpp>
78
#include <opencv2/highgui.hpp>
89
#include <opencv2/saliency.hpp>
9-
//#include <opencv2/datasets/saliency_mit1003.hpp>
10+
#include <opencv2/datasets/saliency_mit1003.hpp>
1011
#include <vector>
1112
#include <string>
1213
#include <iostream>
@@ -16,34 +17,61 @@ using namespace std;
1617
using namespace cv;
1718
using namespace cv::dnn;
1819
using namespace cv::saliency;
19-
//using namespace cv::datasets;
20-
/* Find best class for the blob (i. e. class with maximal probability) */
20+
using namespace cv::datasets;
2121

22-
int main()
22+
23+
int main(int argc, char* argv[])
2324
{
24-
DeepGaze1 g = DeepGaze1();
25-
vector<Mat> images;
26-
vector<Mat> fixs;
25+
const char *keys =
26+
"{ help h usage ? | | show this message }"
27+
"{ train |0 | set training on }"
28+
"{ default |1 | use default deep net(AlexNet) and default weights }"
29+
"{ img_path | | path to folder with img }"
30+
"{ fix_path | | path to folder with fixation img for compute AUC }"
31+
"{ model_path |bvlc_alexnet.caffemodel | path to your caffe model }"
32+
"{ proto_path |deploy.prototxt | path to your deep net caffe prototxt }"
33+
"{ dataset_path d |./ | path to Dataset for training }";
34+
35+
CommandLineParser parser(argc, argv, keys);
36+
if (parser.has("help"))
37+
{
38+
parser.printMessage();
39+
return 0;
40+
}
41+
string img_path = parser.get<string>("img_path");
42+
string model_path = parser.get<string>("model_path");
43+
string proto_path = parser.get<string>("proto_path");
44+
string dataset_path = parser.get<string>("dataset_path");
45+
string fix_path = parser.get<string>("fix_path");
46+
47+
DeepGaze1 g;
48+
if ( parser.get<bool>( "default" ) )
49+
{
50+
g = DeepGaze1( proto_path, model_path );
51+
}
52+
else
53+
{
54+
g = DeepGaze1( proto_path, model_path, vector<string>(1, "conv5"), 257 );
55+
}
2756

2857
//Download mit1003 saliency dataset in the working directory
2958
//ALLSTIMULI folder stores images
3059
//ALLFIXATIONMAPS folder stores training eye fixations
31-
//************ Code only work in linux platform ****
3260

33-
/* string dataset_path;
61+
if ( parser.get<bool>( "train" ) )
62+
{
63+
Ptr<SALIENCY_mit1003> datasetConnector = SALIENCY_mit1003::create();
64+
datasetConnector->load( dataset_path );
65+
vector<vector<Mat> > dataset( datasetConnector->getDataset() );
3466

35-
cin >> dataset_path;
36-
Ptr<SALIENCY_mit1003> datasetConnector = SALIENCY_mit1003::create();
37-
datasetConnector->load( dataset_path );
38-
vector<vector<Mat> > dataset( datasetConnector->getDataset() );
67+
g.training( dataset[0], dataset[1], 1, 200, 0.9, 0.000001, 0.01);
68+
}
3969

40-
g.training( dataset[0], dataset[1] );
41-
*/
4270
ofstream file;
4371
Mat res2;
44-
g.computeSaliency( imread( "ALLSTIMULI/i05june05_static_street_boston_p1010764.jpeg"), res2 );
72+
g.computeSaliency( imread( img_path ), res2 );
4573
resize( res2, res2, Size( 1024, 768 ) );
46-
cout << "AUC = " << g.computeAUC( res2, imread( "ALLFIXATIONMAPS/i05june05_static_street_boston_p1010764_fixMap.jpg", 0 ) ) << endl;;
74+
cout << "AUC = " << g.computeAUC( res2, imread( fix_path, 0 ) ) << endl;;
4775
g.saliencyMapVisualize( res2 );
4876
file.open( "saliency.csv" );
4977
for ( int i = 0; i < res2.rows; i++)

modules/saliency/src/DeepGaze1.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ namespace cv
2929
namespace saliency
3030
{
3131

32-
DeepGaze1::DeepGaze1()
32+
DeepGaze1::DeepGaze1( string net_proto, string net_caffemodel )
3333
{
34-
net = dnn::readNetFromCaffe( "deploy.prototxt", "bvlc_alexnet.caffemodel" );
34+
net = dnn::readNetFromCaffe( net_proto, net_caffemodel );
3535
layers_names.push_back("conv5");
3636
double tmp[] = {0.471836,0.459718,0.457611,0.480085,0.48823,0.462776,0.483445,0.461641,0.483469,0.47846,0.470172,0.383401,0.459975,0.478768,0.576305,0.469454,0.480448,0.435255,0.493571,0.48153,0.441619,0.45436,0.451791,0.476142,0.475955,0.468466,0.479836,0.530578,0.481301,0.454639,0.464725,0.45218,0.502245,0.44709,0.483469,0.500928,0.469562,0.46483,0.453324,0.461538,0.475741,0.480432,0.485595,0.462417,0.495092,0.471557,0.495046,0.459551,0.43668,0.505385,0.478419,0.492535,0.303292,0.475142,0.459992,0.454734,0.466868,0.450649,0.479587,0.434151,0.471309,0.460742,0.49318,0.524707,0.470968,0.478263,0.469935,0.459639,0.490684,0.465349,0.44842,0.481436,0.488862,0.468849,0.492233,0.467677,0.448416,0.474485,0.47684,0.492617,0.455164,0.46794,0.463009,0.47758,0.46629,0.495621,0.464325,0.473217,0.459664,0.478029,0.438637,0.447406,0.438148,0.455966,0.473499,0.473359,0.466213,0.525776,0.434224,0.464641,0.475869,0.501644,0.485892,0.483617,0.47226,0.482615,0.448091,0.460951,0.470457,0.469719,0.474948,0.516341,0.474467,0.429576,0.460061,0.446831,0.429813,0.479859,0.509008,0.504804,0.477351,0.461487,0.445481,0.44935,0.482019,0.469048,0.473205,0.460742,0.474685,0.461985,0.497119,0.464336,0.469783,0.464748,0.477133,0.484101,0.491574,0.591169,0.47327,0.467959,0.479773,0.465179,0.456533,0.42534,0.457655,0.474379,0.482501,0.491678,0.558077,0.473311,0.483722,0.474757,0.46874,0.459033,0.483051,0.475974,0.449861,0.456586,0.462686,0.46992,0.424458,0.492504,0.450006,0.468069,0.450585,0.442672,0.460277,0.460656,0.449303,0.470552,0.433665,0.47603,0.449626,0.471062,0.481555,0.427269,0.424295,0.588326,0.475818,0.484487,0.496265,0.480074,0.45834,0.469174,0.474869,0.49295,0.458737,0.461799,0.487588,0.488148,0.47734,0.480953,0.478616,0.470873,0.456516,0.461151,0.497269,0.449723,0.414189,0.473214,0.47472,0.478068,0.454312,0.485553,0.43564,0.469596,0.450846,0.488699,0.481056,0.419303,0.479696,0.471458,0.456179,0.465579,0.449656,0.459427,0.475431,0.518732,0.45971,0.51276,0.475805,0.467066,0.455423,0.462
425,0.468577,0.429871,0.467098,0.467196,0.48245,0.496047,0.439613,0.446267,0.478326,0.463222,0.466251,0.475164,0.460792,0.407577,0.475157,0.465814,0.480478,0.490252,0.485834,0.455555,0.488025,0.472621,0.482393,0.48254,0.500558,0.466278,0.478975,0.423606,0.482795,0.486593,0.488191,0.483121,5.96006};
3737
weights = vector<double>( tmp, tmp + 257 );
@@ -48,6 +48,13 @@ DeepGaze1::DeepGaze1( string net_proto, string net_caffemodel, vector<string> se
4848
}
4949
}
5050

51+
// Constructs a DeepGaze1 from caller-supplied layer names and weights.
// net_proto / net_caffemodel : passed straight to dnn::readNetFromCaffe to load the network.
// selected_layers            : copied into layers_names as given.
// i_weights                  : copied into weights as given.
// NOTE(review): nothing here checks that i_weights matches the selected layers
// in size — presumably the caller must guarantee that; confirm against the users of weights.
DeepGaze1::DeepGaze1( string net_proto, string net_caffemodel, vector<string> selected_layers, vector<double> i_weights )
52+
{
53+
net = dnn::readNetFromCaffe( net_proto, net_caffemodel );
54+
layers_names = selected_layers;
55+
weights = i_weights;
56+
}
57+
5158
DeepGaze1::~DeepGaze1(){}
5259

5360
vector<Mat> DeepGaze1::featureMapGenerator( Mat img, Size input_size )

0 commit comments

Comments
 (0)