Skip to content

Commit b502b48

Browse files
committed
fix cnn_classify
2 parents 4530263 + 9104d7e commit b502b48

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

cnn_classifier.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class CNNClassifier(object):
3131
def __init__(self, model_file, label_file, input_layer="input", output_layer="final_result", input_height=128, input_width=128, input_mean=127.5, input_std=127.5):
3232
self._graph = self.load_graph(model_file)
3333
self._labels = self.load_labels(label_file)
34+
self.input_height=input_height
35+
self.input_width=input_width
3436
input_name = "import/" + input_layer
3537
output_name = "import/" + output_layer
3638
self._input_operation = self._graph.get_operation_by_name(input_name)
@@ -78,8 +80,8 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
7880
image_reader = tf.image.decode_jpeg(file_reader, channels=3, name='jpeg_reader')
7981

8082
float_caster = tf.cast(image_reader, tf.float32)
81-
dims_expander = tf.expand_dims(float_caster, 0)
82-
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
83+
dims_expander = tf.expand_dims(float_caster, 0);
84+
resized = tf.image.resize_bilinear(dims_expander, [self.input_height, self.input_width])
8385
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
8486
sess = tf.Session()
8587

@@ -108,7 +110,6 @@ def classify_image(self,
108110
else:
109111
t = self.read_tensor_from_image_mat(image_file_or_mat)
110112

111-
112113
results = self._session.run(self._output_operation.outputs[0],
113114
{self._input_operation.outputs[0]: t})
114115

cnn_models/models.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"test_model_1": {"status": 1, "image_width": "128", "image_height": "128", "output_layer": "final_result"}, "test_model_2": {"status": 1, "image_width": "128", "image_height": "128", "output_layer": "final_result"}}
1+
{}

test/cnn_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ def test_train_set_1(self):
3737
time.sleep(1)
3838
self.assertTrue(cnn.get_models().get(name).get("status") == 1.0)
3939

40-
name="test_model_1"
4140
cnn = cnn_manager.CNNManager.get_instance()
4241
mod = cnn.load_model(name)
4342
result = mod.classify_image("photos/DSC86.jpg")
4443
print("result: " + str(result))
45-
self.assertTrue(result["kiwi"] == 1.0)
44+
self.assertTrue(result[0][0] == "kiwi" and result[0][1] == 1.0)
4645

47-
name="test_model_1"
4846
cnn = cnn_manager.CNNManager.get_instance()
4947
cnn.delete_model(name)
5048
self.assertTrue(cnn.get_models().get(name) is None)
@@ -72,7 +70,7 @@ def test_train_set_2(self):
7270
mod = cnn.load_model(name)
7371
result = mod.classify_image("photos/DSC86.jpg")
7472
print("result: " + str(result))
75-
self.assertTrue(result["kiwi"] >= 0.9)
73+
self.assertTrue(result[0][0] == "kiwi" and result[0][1] > 0.9)
7674

7775
cnn = cnn_manager.CNNManager.get_instance()
7876
cnn.delete_model(name)

0 commit comments

Comments
 (0)