Skip to content

Commit 40a12ac

Browse files
committed
implement #107
1 parent c3868ef commit 40a12ac

File tree

11 files changed

+106
-55
lines changed

11 files changed

+106
-55
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ photos/*.mp4
3838
photos/*.h264
3939
photos/*.json
4040

41+
#cnn models
42+
cnn_models/*
43+
cnn_models/cache/*
44+
4145
# Sounds recorded
4246
sounds/*.wav
4347

camera.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def __init__(self):
8484
self._photos.append({'name': filename})
8585
self.save_photo_metadata()
8686

87-
#self._cnn_classifiers = {}
88-
#cnn_model = config.Config.get().get("cnn_default_model", "")
89-
#if cnn_model != "":
90-
# self._cnn_classifiers[cnn_model] = CNNManager.get_instance().load_model(cnn_model)
91-
# self._cnn_classifier_default = self._cnn_classifiers[cnn_model]
87+
self._cnn_classifiers = {}
88+
cnn_model = config.Config.get().get("cnn_default_model", "")
89+
if cnn_model != "":
90+
self._cnn_classifiers[cnn_model] = CNNManager.get_instance().load_model(cnn_model)
91+
self._cnn_classifier_default = self._cnn_classifiers[cnn_model]
9292

9393
self._camera.grab_start()
9494
self._image_cv = self.get_image()

cnn_classifier.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import time
2828
import logging
2929

30+
import operator
3031
import numpy as np
3132
import tensorflow as tf
3233

@@ -137,10 +138,7 @@ def classify_image(self,
137138
#logging.info( "time.cls: " + str(time.time() - s_t))
138139

139140
results = np.squeeze(results)
140-
141-
pairs = {}
142-
for i in results.argsort():
143-
pairs[self._labels[i]] = results[i]
144-
145-
#logging.info(pairs)
141+
result = results.argmax()
142+
pairs = {self._labels[result]: results[result]}
143+
logging.info(pairs)
146144
return pairs

cnn_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def train_new_model(self,
9393

9494
def save_model_status(self, model_name, architecture, status):
9595
model_info = architecture.split("_")
96-
self._models[model_name] = {"status": status, "image_height": model_info[2], "image_width": model_info[2]}
96+
self._models[model_name] = {"status": status, "image_height": model_info[3], "image_width": model_info[3], "output_layer": "final_result"}
9797
self._save_model_meta()
9898

9999
def wait_train_jobs(self):
@@ -105,6 +105,7 @@ def load_model(self, model_name):
105105
if model_info:
106106
return CNNClassifier(model_file = MODEL_PATH + "/" + model_name + ".pb",
107107
label_file = MODEL_PATH + "/" + model_name + ".txt",
108+
output_layer=model_info["output_layer"],
108109
input_height = int(model_info["image_height"]),
109110
input_width = int(model_info["image_width"]))
110111

cnn_models/models.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{}
1+
{"generic_fast_low":{"status":1.0, "image_height": "128", "image_width":"128", "output_layer": "MobilenetV2/Predictions/Reshape_1"}, "generic_slow_high":{"status":1.0, "image_height":"224", "image_width": "224", "output_layer": "MobilenetV2/Predictions/Reshape_1"}}

cnn_train.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(self, manager, architecture):
142142
self.validation_batch_size = 100
143143
self.print_misclassified_test_images = False
144144
self.bottleneck_dir = "/tmp/bottleneck"
145-
self.model_dir = "/tmp/imagenet"
145+
self.model_dir = "./cnn_models/cache"
146146
self.final_tensor_name = "final_result"
147147
self.write_logs = False
148148

@@ -160,7 +160,7 @@ def __init__(self, manager, architecture):
160160
raise Exception("Did not recognize architecture flag'")
161161

162162
# Set up the pre-trained graph.
163-
self.maybe_download_and_extract(self.model_info['data_url'])
163+
self.maybe_download_and_extract(self.model_info['data_url'], self.model_info['model_dir_name'])
164164
self.graph, self.bottleneck_tensor, self.resized_image_tensor = (
165165
self.create_model_graph(self.model_info))
166166

@@ -517,7 +517,7 @@ def run_bottleneck_on_image(self, sess, image_data, image_data_tensor,
517517
return bottleneck_values
518518

519519

520-
def maybe_download_and_extract(self, data_url):
520+
def maybe_download_and_extract(self, data_url, model_dir_name):
521521
"""Download and extract model tar file.
522522
523523
If the pretrained model we're using doesn't already exist, this function
@@ -526,7 +526,7 @@ def maybe_download_and_extract(self, data_url):
526526
Args:
527527
data_url: Web location of the tar file containing the pretrained model.
528528
"""
529-
dest_directory = self.model_dir
529+
dest_directory = os.path.join(self.model_dir, model_dir_name)
530530
if not os.path.exists(dest_directory):
531531
os.makedirs(dest_directory)
532532
filename = data_url.split('/')[-1]
@@ -538,11 +538,10 @@ def _progress(count, block_size, total_size):
538538
(filename,
539539
float(count * block_size) / float(total_size) * 100.0))
540540
sys.stdout.flush()
541-
542541
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
543542
print()
544543
statinfo = os.stat(filepath)
545-
tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
544+
tf.logging.info('Successfully downloaded %s %d', filename, statinfo.st_size,
546545
'bytes.')
547546
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
548547

@@ -1084,47 +1083,64 @@ def create_model_info(self, architecture):
10841083
tf.logging.error("Couldn't understand architecture name '%s'",
10851084
architecture)
10861085
return None
1087-
version_string = parts[1]
1086+
v_string = parts[1]
1087+
version_string = parts[2]
10881088
if (version_string != '1.0' and version_string != '0.75' and
1089-
version_string != '0.50' and version_string != '0.25'):
1089+
version_string != '0.50' and version_string != '0.5' and
1090+
version_string != '0.35' and version_string != '0.25'):
10901091
tf.logging.error(
1091-
""""The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
1092+
""""The Mobilenet version should be '1.0', '0.75', '0.50', '0.35', or '0.25',
10921093
but found '%s' for architecture '%s'""",
10931094
version_string, architecture)
10941095
return None
1095-
size_string = parts[2]
1096+
size_string = parts[3]
10961097
if (size_string != '224' and size_string != '192' and
10971098
size_string != '160' and size_string != '128'):
10981099
tf.logging.error(
10991100
"""The Mobilenet input size should be '224', '192', '160', or '128',
11001101
but found '%s' for architecture '%s'""",
11011102
size_string, architecture)
11021103
return None
1103-
if len(parts) == 3:
1104+
if len(parts) == 4:
11041105
is_quantized = False
11051106
else:
1106-
if parts[3] != 'quantized':
1107+
if parts[4] != 'quantized':
11071108
tf.logging.error(
11081109
"Couldn't understand architecture suffix '%s' for '%s'", parts[3],
11091110
architecture)
11101111
return None
11111112
is_quantized = True
1112-
data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
1113-
data_url += version_string + '_' + size_string + '_frozen.tgz'
1114-
bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1113+
data_url = 'http://'
1114+
model_file_name = None
1115+
bottleneck_tensor_name = None
1116+
if architecture.startswith('mobilenet_v1'):
1117+
data_url += 'download.tensorflow.org/models/mobilenet_v1_'
1118+
data_url += version_string + '_' + size_string + '_frozen.tgz'
1119+
bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1120+
if is_quantized:
1121+
model_base_name = 'quantized_graph.pb'
1122+
else:
1123+
model_base_name = 'frozen_graph.pb'
1124+
model_dir_name = 'mobilenet_v1_'
1125+
model_dir_name += version_string + '_' + size_string
1126+
model_file_name = os.path.join(model_dir_name, model_base_name)
1127+
model_dir_name = ''
1128+
else:
1129+
data_url += 'storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_'
1130+
data_url += version_string + '_' + size_string + '.tgz'
1131+
bottleneck_tensor_name = 'MobilenetV2/Predictions/Reshape:0'
1132+
model_dir_name = 'mobilenet_v2_'
1133+
model_dir_name += version_string + '_' + size_string
1134+
model_base_name = model_dir_name + '_frozen.pb'
1135+
model_file_name = os.path.join(model_dir_name, model_base_name)
11151136
bottleneck_tensor_size = 1001
11161137
input_width = int(size_string)
11171138
input_height = int(size_string)
11181139
input_depth = 3
11191140
resized_input_tensor_name = 'input:0'
1120-
if is_quantized:
1121-
model_base_name = 'quantized_graph.pb'
1122-
else:
1123-
model_base_name = 'frozen_graph.pb'
1124-
model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
1125-
model_file_name = os.path.join(model_dir_name, model_base_name)
11261141
input_mean = 127.5
11271142
input_std = 127.5
1143+
print(data_url)
11281144
else:
11291145
tf.logging.error("Couldn't understand architecture name '%s'", architecture)
11301146
raise ValueError('Unknown architecture', architecture)
@@ -1138,6 +1154,7 @@ def create_model_info(self, architecture):
11381154
'input_depth': input_depth,
11391155
'resized_input_tensor_name': resized_input_tensor_name,
11401156
'model_file_name': model_file_name,
1157+
'model_dir_name': model_dir_name,
11411158
'input_mean': input_mean,
11421159
'input_std': input_std,
11431160
}
@@ -1170,11 +1187,3 @@ def add_jpeg_decoding(self, input_width, input_height, input_depth, input_mean,
11701187
mul_image = tf.multiply(offset_image, 1.0 / input_std)
11711188
return jpeg_data, mul_image
11721189

1173-
1174-
if __name__ == '__main__':
1175-
cnn_trainer = CNNTrainer("mobilenet_0.50_128")
1176-
cnn_trainer.retrain(
1177-
image_dir="/home/pi/tensorflow/data/applekiwi",
1178-
output_graph="./cnn_models/applewiki_0_5_128.pb",
1179-
training_steps=10,
1180-
learning_rate=0.1)

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
# Logging configuration
3636
logger = logging.getLogger()
37-
logger.setLevel(logging.WARNING)
37+
logger.setLevel(logging.DEBUG)
3838
sh = logging.StreamHandler()
3939
fh = logging.handlers.RotatingFileHandler('./logs/coderbot.log', maxBytes=1000000, backupCount=5)
4040
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -96,7 +96,7 @@ def handle_home():
9696
config=app.bot_config,
9797
program_level=app.bot_config.get("prog_level", "std"),
9898
cam=cam != None,
99-
cnn_model_names = json.dumps({}))
99+
cnn_model_names = json.dumps([[name] for name in cnn.get_models().keys()]))
100100

101101
@babel.localeselector
102102
def get_locale():

photos/metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
[]
1+
[{"name": "DSC2.jpg"}]

scripts/interfaces_ap

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
auto lo
2+
allow-hotplug wlan0
3+
iface lo inet loopback
4+
iface eth0 inet dhcp
5+
iface wlan0 inet static
6+
address 10.0.0.1
7+
netmask 255.255.255.0
8+
wireless-power off
9+
network 10.0.0.0
10+
broadcast 10.0.0.255

templates/config.html

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,14 @@ <h3>{% trans %}Train new CNN Model{% endtrans %}</h3>
251251
<input type="text" id="i_cnn_image_tags" name="cnn_image_tags" value="">
252252
<label for="i_cnn_model_arch">{% trans %}Model architecture{% endtrans %}</label>
253253
<select name="cnn_model_arch" id="i_cnn_model_arch" value="">
254-
     <option value="mobilenet_1.0_128">MobileNet 1.0 128</option>
255-
     <option value="mobilenet_0.75_128">MobileNet 0.75 128</option>
256-
     <option value="mobilenet_0.50_128">MobileNet 0.50 128</option>
257-
     <option value="mobilenet_0.25_128">MobileNet 0.25 128</option>
258-
     <option value="mobilenet_1.0_224">MobileNet 1.0 224</option>
259-
     <option value="mobilenet_0.75_224">MobileNet 0.75 224</option>
260-
     <option value="mobilenet_0.50_224">MobileNet 0.50 224</option>
261-
     <option value="mobilenet_0.25_224">MobileNet 0.25 224</option>
254+
     <option value="mobilenet_v2_1.0_128">MobileNet v2 1.0 128</option>
255+
     <option value="mobilenet_v2_0.35_128">MobileNet v2 0.35 128</option>
256+
     <option value="mobilenet_v2_1.0_224">MobileNet v2 1.0 224</option>
257+
     <option value="mobilenet_v2_0.35_224">MobileNet v2 0.35 224</option>
258+
     <option value="mobilenet_1.0_128">MobileNet v1 1.0 128</option>
259+
     <option value="mobilenet_0.25_128">MobileNet v1 0.25 128</option>
260+
     <option value="mobilenet_1.0_224">MobileNet v1 1.0 224</option>
261+
     <option value="mobilenet_0.25_224">MobileNet v1 0.25 224</option>
262262
</select>
263263
<label for="i_cnn_model_train_steps">{% trans %}Training steps{% endtrans %}</label>
264264
<input type="range" id="i_cnn_train_steps" name="cnn_train_steps" min="10" max="100" step="10" value="50">

test/cnn_test.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def test_train_set_1(self):
2323
cam = camera.Camera.get_instance()
2424
cnn = cnn_manager.CNNManager.get_instance()
2525
cnn.train_new_model(model_name=name,
26-
architecture="mobilenet_0.25_128",
26+
architecture="mobilenet_v1_0.50_128",
2727
image_tags=["other","tomato","apple","kiwi"],
2828
photos_meta=cam.get_photo_list(),
2929
training_steps=20,
3030
learning_rate=0.1)
3131
start = time.time()
3232
while True:
3333
model_status = cnn.get_models().get(name, {"status": 0})
34-
if model_status.get("status") == 1 or time.time() - start > 600:
34+
if model_status.get("status") == 1 or time.time() - start > 6000:
3535
break
3636
print("status: " + str(model_status["status"]))
3737
time.sleep(1)
@@ -48,3 +48,32 @@ def test_train_set_1(self):
4848
cnn = cnn_manager.CNNManager.get_instance()
4949
cnn.delete_model(name)
5050
self.assertTrue(cnn.get_models().get(name) is None)
51+
52+
def test_train_set_2(self):
53+
name="test_model_2"
54+
cam = camera.Camera.get_instance()
55+
cnn = cnn_manager.CNNManager.get_instance()
56+
cnn.train_new_model(model_name=name,
57+
architecture="mobilenet_v2_0.5_128",
58+
image_tags=["other","tomato","apple","kiwi"],
59+
photos_meta=cam.get_photo_list(),
60+
training_steps=20,
61+
learning_rate=0.1)
62+
start = time.time()
63+
while True:
64+
model_status = cnn.get_models().get(name, {"status": 0})
65+
if model_status.get("status") == 1 or time.time() - start > 6000:
66+
break
67+
print("status: " + str(model_status["status"]))
68+
time.sleep(1)
69+
self.assertTrue(cnn.get_models().get(name).get("status") == 1.0)
70+
71+
cnn = cnn_manager.CNNManager.get_instance()
72+
mod = cnn.load_model(name)
73+
result = mod.classify_image("photos/DSC86.jpg")
74+
print("result: " + str(result))
75+
self.assertTrue(result["kiwi"] >= 0.9)
76+
77+
cnn = cnn_manager.CNNManager.get_instance()
78+
cnn.delete_model(name)
79+
self.assertTrue(cnn.get_models().get(name) is None)

0 commit comments

Comments
 (0)