Skip to content

Commit 5cc6f54

Browse files
committed
resiliency to invalid models config, fix #107
1 parent bc38f28 commit 5cc6f54

File tree

3 files changed

+23
-24
lines changed

3 files changed

+23
-24
lines changed

camera.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,12 @@ def __init__(self):
8787
self._cnn_classifiers = {}
8888
cnn_model = config.Config.get().get("cnn_default_model", "")
8989
if cnn_model != "":
90-
self._cnn_classifiers[cnn_model] = CNNManager.get_instance().load_model(cnn_model)
91-
self._cnn_classifier_default = self._cnn_classifiers[cnn_model]
92-
logging.info("loaded: " + cnn_model + " " + str(self._cnn_classifier_default))
90+
try:
91+
self._cnn_classifiers[cnn_model] = CNNManager.get_instance().load_model(cnn_model)
92+
self._cnn_classifier_default = self._cnn_classifiers[cnn_model]
93+
logging.info("loaded: " + cnn_model + " " + str(self._cnn_classifier_default))
94+
except:
95+
logging.warning("model not found: " + cnn_model)
9396

9497
self._camera.grab_start()
9598
self._image_cv = self.get_image()
@@ -354,7 +357,7 @@ def find_ar_code(self):
354357
img = self.get_image()
355358
return img.find_ar_code()
356359

357-
def cnn_classify(self, model_name=None):
360+
def cnn_classify(self, model_name=None, top_results=3):
358361
classifier = None
359362
if model_name:
360363
classifier = self._cnn_classifiers.get(model_name)
@@ -364,12 +367,17 @@ def cnn_classify(self, model_name=None):
364367
else:
365368
classifier = self._cnn_classifier_default
366369

367-
img = self.get_image()
368-
classes = classifier.classify_image(img.mat())
370+
classes = None
371+
try:
372+
img = self.get_image()
373+
classes = classifier.classify_image(img.mat(), top_results=top_results)
374+
except:
375+
logging.warning("classifier not available")
376+
classes = [("None", 1.0)]
369377
return classes
370378

371379
def find_class(self):
372-
return self.cnn_classify()[0]
380+
return self.cnn_classify(top_results=1)[0][0]
373381

374382
def sleep(self, elapse):
375383
logging.debug("sleep: " + str(elapse))

cnn_classifier.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,24 +110,13 @@ def load_labels(self, label_file):
110110

111111
def classify_image(self,
112112
image_file_or_mat,
113-
input_height=128,
114-
input_width=128,
115-
input_mean=0,
116-
input_std=255):
113+
top_results=3):
117114
s_t = time.time()
118115
t = None
119116
if type(image_file_or_mat) == str:
120-
t = self.read_tensor_from_image_file(file_name=image_file_or_mat,
121-
input_height=input_height,
122-
input_width=input_width,
123-
input_mean=input_mean,
124-
input_std=input_std)
117+
t = self.read_tensor_from_image_file(file_name=image_file_or_mat)
125118
else:
126-
t = self.read_tensor_from_image_mat(image_file_or_mat,
127-
input_height=input_height,
128-
input_width=input_width,
129-
input_mean=input_mean,
130-
input_std=input_std)
119+
t = self.read_tensor_from_image_mat(image_file_or_mat)
131120

132121
#logging.info( "time.norm: " + str(time.time() - s_t))
133122
s_t = time.time()
@@ -137,7 +126,9 @@ def classify_image(self,
137126

138127
#logging.info( "time.cls: " + str(time.time() - s_t))
139128

129+
top_results = min(top_results, len(self._labels))
140130
results = np.squeeze(results)
141-
result = results.argmax()
142-
pairs = (self._labels[result], results[result])
131+
results_idx = np.argpartition(results, -top_results)[-top_results:]
132+
results_idx = np.flip(results_idx[np.argsort(results[results_idx])], axis=0)
133+
pairs = [(self._labels[i], results[i]) for i in results_idx]
143134
return pairs

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
# Logging configuration
3636
logger = logging.getLogger()
37-
logger.setLevel(logging.WARNING)
37+
logger.setLevel(logging.INFO)
3838
sh = logging.StreamHandler()
3939
fh = logging.handlers.RotatingFileHandler('./logs/coderbot.log', maxBytes=1000000, backupCount=5)
4040
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

0 commit comments

Comments
 (0)