@@ -31,6 +31,8 @@ class CNNClassifier(object):
31
31
def __init__ (self , model_file , label_file , input_layer = "input" , output_layer = "final_result" , input_height = 128 , input_width = 128 , input_mean = 127.5 , input_std = 127.5 ):
32
32
self ._graph = self .load_graph (model_file )
33
33
self ._labels = self .load_labels (label_file )
34
+ self .input_height = input_height
35
+ self .input_width = input_width
34
36
input_name = "import/" + input_layer
35
37
output_name = "import/" + output_layer
36
38
self ._input_operation = self ._graph .get_operation_by_name (input_name )
@@ -78,8 +80,8 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
78
80
image_reader = tf .image .decode_jpeg (file_reader , channels = 3 , name = 'jpeg_reader' )
79
81
80
82
float_caster = tf .cast (image_reader , tf .float32 )
81
- dims_expander = tf .expand_dims (float_caster , 0 )
82
- resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
83
+ dims_expander = tf .expand_dims (float_caster , 0 );
84
+ resized = tf .image .resize_bilinear (dims_expander , [self . input_height , self . input_width ])
83
85
normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ])
84
86
sess = tf .Session ()
85
87
@@ -108,7 +110,6 @@ def classify_image(self,
108
110
else :
109
111
t = self .read_tensor_from_image_mat (image_file_or_mat )
110
112
111
-
112
113
results = self ._session .run (self ._output_operation .outputs [0 ],
113
114
{self ._input_operation .outputs [0 ]: t })
114
115
0 commit comments