Skip to content

Commit 9104d7e

Browse files
committed
fix cnn_classify
1 parent 92d2905 commit 9104d7e

File tree

3 files changed

+6
-12
lines changed

3 files changed

+6
-12
lines changed

cnn_classifier.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class CNNClassifier(object):
3737
def __init__(self, model_file, label_file, input_layer="input", output_layer="final_result", input_height=128, input_width=128, input_mean=127.5, input_std=127.5):
3838
self._graph = self.load_graph(model_file)
3939
self._labels = self.load_labels(label_file)
40+
self.input_height=input_height
41+
self.input_width=input_width
4042
input_name = "import/" + input_layer
4143
output_name = "import/" + output_layer
4244
self._input_operation = self._graph.get_operation_by_name(input_name);
@@ -88,7 +90,7 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
8890

8991
float_caster = tf.cast(image_reader, tf.float32)
9092
dims_expander = tf.expand_dims(float_caster, 0);
91-
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
93+
resized = tf.image.resize_bilinear(dims_expander, [self.input_height, self.input_width])
9294
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
9395
sess = tf.Session()
9496

@@ -111,21 +113,15 @@ def load_labels(self, label_file):
111113
def classify_image(self,
112114
image_file_or_mat,
113115
top_results=3):
114-
s_t = time.time()
115116
t = None
116117
if type(image_file_or_mat) == str:
117118
t = self.read_tensor_from_image_file(file_name=image_file_or_mat)
118119
else:
119120
t = self.read_tensor_from_image_mat(image_file_or_mat)
120121

121-
#logging.info( "time.norm: " + str(time.time() - s_t))
122-
s_t = time.time()
123-
124122
results = self._session.run(self._output_operation.outputs[0],
125123
{self._input_operation.outputs[0]: t})
126124

127-
#logging.info( "time.cls: " + str(time.time() - s_t))
128-
129125
top_results = min(top_results, len(self._labels))
130126
results = np.squeeze(results)
131127
results_idx = np.argpartition(results, -top_results)[-top_results:]

cnn_models/models.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"test_model_1": {"status": 1, "image_width": "128", "image_height": "128", "output_layer": "final_result"}, "test_model_2": {"status": 1, "image_width": "128", "image_height": "128", "output_layer": "final_result"}}
1+
{}

test/cnn_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ def test_train_set_1(self):
3737
time.sleep(1)
3838
self.assertTrue(cnn.get_models().get(name).get("status") == 1.0)
3939

40-
name="test_model_1"
4140
cnn = cnn_manager.CNNManager.get_instance()
4241
mod = cnn.load_model(name)
4342
result = mod.classify_image("photos/DSC86.jpg")
4443
print("result: " + str(result))
45-
self.assertTrue(result["kiwi"] == 1.0)
44+
self.assertTrue(result[0][0] == "kiwi" and result[0][1] == 1.0)
4645

47-
name="test_model_1"
4846
cnn = cnn_manager.CNNManager.get_instance()
4947
cnn.delete_model(name)
5048
self.assertTrue(cnn.get_models().get(name) is None)
@@ -72,7 +70,7 @@ def test_train_set_2(self):
7270
mod = cnn.load_model(name)
7371
result = mod.classify_image("photos/DSC86.jpg")
7472
print("result: " + str(result))
75-
self.assertTrue(result["kiwi"] >= 0.9)
73+
self.assertTrue(result[0][0] == "kiwi" and result[0][1] > 0.9)
7674

7775
cnn = cnn_manager.CNNManager.get_instance()
7876
cnn.delete_model(name)

0 commit comments

Comments
 (0)