20
20
This module implements the CNNClassifier class, which is the interface for
21
21
using an existing and trained CNN model.
22
22
"""
23
- from __future__ import absolute_import
24
- from __future__ import division
25
- from __future__ import print_function
26
-
27
- import time
28
23
import logging
29
24
30
- import operator
31
25
import numpy as np
32
26
import tensorflow as tf
33
27
@@ -39,17 +33,17 @@ def __init__(self, model_file, label_file, input_layer="input", output_layer="fi
39
33
self ._labels = self .load_labels (label_file )
40
34
input_name = "import/" + input_layer
41
35
output_name = "import/" + output_layer
42
- self ._input_operation = self ._graph .get_operation_by_name (input_name );
43
- self ._output_operation = self ._graph .get_operation_by_name (output_name );
36
+ self ._input_operation = self ._graph .get_operation_by_name (input_name )
37
+ self ._output_operation = self ._graph .get_operation_by_name (output_name )
44
38
self ._session = tf .Session (graph = self ._graph )
45
39
self ._graph_norm = tf .Graph ()
46
40
with self ._graph_norm .as_default ():
47
41
image_mat = tf .placeholder (tf .float32 , None , name = "image_rgb_in" )
48
42
float_caster = tf .cast (image_mat , tf .float32 )
49
- dims_expander = tf .expand_dims (float_caster , 0 );
43
+ dims_expander = tf .expand_dims (float_caster , 0 )
50
44
resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
51
45
normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ], name = "image_norm_out" )
52
- self ._input_operation_norm = self ._graph_norm .get_operation_by_name ("image_rgb_in" )
46
+ self ._input_operation_norm = self ._graph_norm .get_operation_by_name ("image_rgb_in" )
53
47
self ._output_operation_norm = self ._graph_norm .get_operation_by_name ("image_norm_out" )
54
48
self ._sess_norm = tf .Session (graph = self ._graph_norm )
55
49
@@ -75,19 +69,16 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
75
69
file_reader = tf .read_file (file_name , input_name )
76
70
77
71
if file_name .endswith (".png" ):
78
- image_reader = tf .image .decode_png (file_reader , channels = 3 ,
79
- name = 'png_reader' )
72
+ image_reader = tf .image .decode_png (file_reader , channels = 3 , name = 'png_reader' )
80
73
elif file_name .endswith (".gif" ):
81
- image_reader = tf .squeeze (tf .image .decode_gif (file_reader ,
82
- name = 'gif_reader' ))
74
+ image_reader = tf .squeeze (tf .image .decode_gif (file_reader , name = 'gif_reader' ))
83
75
elif file_name .endswith (".bmp" ):
84
76
image_reader = tf .image .decode_bmp (file_reader , name = 'bmp_reader' )
85
77
else :
86
- image_reader = tf .image .decode_jpeg (file_reader , channels = 3 ,
87
- name = 'jpeg_reader' )
78
+ image_reader = tf .image .decode_jpeg (file_reader , channels = 3 , name = 'jpeg_reader' )
88
79
89
80
float_caster = tf .cast (image_reader , tf .float32 )
90
- dims_expander = tf .expand_dims (float_caster , 0 );
81
+ dims_expander = tf .expand_dims (float_caster , 0 )
91
82
resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
92
83
normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ])
93
84
sess = tf .Session ()
@@ -111,20 +102,15 @@ def load_labels(self, label_file):
111
102
def classify_image (self ,
112
103
image_file_or_mat ,
113
104
top_results = 3 ):
114
- s_t = time .time ()
115
105
t = None
116
- if type (image_file_or_mat ) == str :
106
+ if isinstance (image_file_or_mat , str ) :
117
107
t = self .read_tensor_from_image_file (file_name = image_file_or_mat )
118
108
else :
119
109
t = self .read_tensor_from_image_mat (image_file_or_mat )
120
110
121
- #logging.info( "time.norm: " + str(time.time() - s_t))
122
- s_t = time .time ()
123
111
124
112
results = self ._session .run (self ._output_operation .outputs [0 ],
125
- {self ._input_operation .outputs [0 ]: t })
126
-
127
- #logging.info( "time.cls: " + str(time.time() - s_t))
113
+ {self ._input_operation .outputs [0 ]: t })
128
114
129
115
top_results = min (top_results , len (self ._labels ))
130
116
results = np .squeeze (results )
0 commit comments