@@ -37,6 +37,8 @@ class CNNClassifier(object):
37
37
def __init__ (self , model_file , label_file , input_layer = "input" , output_layer = "final_result" , input_height = 128 , input_width = 128 , input_mean = 127.5 , input_std = 127.5 ):
38
38
self ._graph = self .load_graph (model_file )
39
39
self ._labels = self .load_labels (label_file )
40
+ self .input_height = input_height
41
+ self .input_width = input_width
40
42
input_name = "import/" + input_layer
41
43
output_name = "import/" + output_layer
42
44
self ._input_operation = self ._graph .get_operation_by_name (input_name );
@@ -88,7 +90,7 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
88
90
89
91
float_caster = tf .cast (image_reader , tf .float32 )
90
92
dims_expander = tf .expand_dims (float_caster , 0 );
91
- resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
93
+ resized = tf .image .resize_bilinear (dims_expander , [self . input_height , self . input_width ])
92
94
normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ])
93
95
sess = tf .Session ()
94
96
@@ -111,21 +113,15 @@ def load_labels(self, label_file):
111
113
def classify_image (self ,
112
114
image_file_or_mat ,
113
115
top_results = 3 ):
114
- s_t = time .time ()
115
116
t = None
116
117
if type (image_file_or_mat ) == str :
117
118
t = self .read_tensor_from_image_file (file_name = image_file_or_mat )
118
119
else :
119
120
t = self .read_tensor_from_image_mat (image_file_or_mat )
120
121
121
- #logging.info( "time.norm: " + str(time.time() - s_t))
122
- s_t = time .time ()
123
-
124
122
results = self ._session .run (self ._output_operation .outputs [0 ],
125
123
{self ._input_operation .outputs [0 ]: t })
126
124
127
- #logging.info( "time.cls: " + str(time.time() - s_t))
128
-
129
125
top_results = min (top_results , len (self ._labels ))
130
126
results = np .squeeze (results )
131
127
results_idx = np .argpartition (results , - top_results )[- top_results :]
0 commit comments