Skip to content

Commit bc89900

Browse files
committed
Adding data augmentation classes
1 parent 53fbda0 commit bc89900

8 files changed

+876
-5
lines changed

odir.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2019 Jordi Corbilla. All Rights Reserved.
1+
# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,15 +15,15 @@
1515
import numpy as np
1616

1717

18-
def load_data(image_size):
18+
def load_data(image_size, index):
1919
"""Loads the ODIR dataset.
2020
2121
Returns:
2222
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
2323
2424
"""
25-
x_train = np.load('odir_training'+'_' + str(image_size)+'.npy')
26-
y_train = np.load('odir_training_labels'+'_' + str(image_size)+'.npy')
25+
x_train = np.load('odir_training'+'_' + str(image_size) + '_' + str(index)+'.npy')
26+
y_train = np.load('odir_training_labels'+'_' + str(image_size) + '_' + str(index)+'.npy')
2727

2828
x_test = np.load('odir_testing'+'_' + str(image_size)+'.npy')
2929
y_test = np.load('odir_testing_labels'+'_' + str(image_size)+'.npy')

odir_advance_plotting.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright 2019 Jordi Corbilla. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from __future__ import absolute_import, division, print_function, unicode_literals
16+
17+
import sys
18+
19+
import matplotlib.pyplot as plt
20+
from sklearn.metrics import confusion_matrix
21+
import numpy as np
22+
import seaborn as sns
23+
import matplotlib as mpl
24+
25+
26+
class Plotter:
27+
def __init__(self, class_names):
28+
self.class_names = class_names
29+
30+
def plot_metrics(self, history, test_run, index):
31+
metrics2 = ['loss', 'auc', 'precision', 'recall']
32+
for n, metric in enumerate(metrics2):
33+
name = metric.replace("_", " ").capitalize()
34+
plt.subplot(2, 2, n + 1)
35+
plt.plot(history.epoch, history.history[metric], color='green', label='Train')
36+
plt.plot(history.epoch, history.history['val_' + metric], color='green', linestyle="--", label='Val')
37+
plt.xlabel('Epoch')
38+
plt.ylabel(name)
39+
if metric == 'loss':
40+
plt.ylim([0, plt.ylim()[1]])
41+
elif metric == 'auc':
42+
plt.ylim([0.8, 1])
43+
else:
44+
plt.ylim([0, 1])
45+
46+
plt.legend()
47+
48+
plt.savefig('image_run' + str(index) + test_run + '.png')
49+
plt.show()
50+
plt.close()
51+
52+
def plot_input_images(self, x_train, y_train):
53+
plt.figure(figsize=(9, 9))
54+
for i in range(100):
55+
plt.subplot(10, 10, i + 1)
56+
plt.xticks([])
57+
plt.yticks([])
58+
plt.grid(False)
59+
plt.imshow(x_train[i])
60+
classes = ""
61+
for j in range(8):
62+
if y_train[i][j] >= 0.5:
63+
classes = classes + self.class_names[j] + "\n"
64+
plt.xlabel(classes, fontsize=7, color='black', labelpad=1)
65+
66+
plt.subplots_adjust(bottom=0.04, right=0.95, top=0.94, left=0.06, wspace=0.56, hspace=0.17)
67+
plt.show()
68+
69+
def plot_image(self, i, predictions_array, true_label, img):
70+
predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
71+
plt.grid(False)
72+
plt.xticks([])
73+
plt.yticks([])
74+
75+
plt.imshow(img)
76+
77+
ground = ""
78+
count_true = 0
79+
predicted_true = 0
80+
81+
for index in range(8):
82+
if true_label[index] >= 0.5:
83+
count_true = count_true + 1
84+
ground = ground + self.class_names[index] + "\n"
85+
if predictions_array[index] >= 0.5:
86+
predicted_true = predicted_true + 1
87+
88+
if count_true == predicted_true:
89+
color = 'green'
90+
else:
91+
color = 'red'
92+
93+
first, second, third, i, j, k = self.calculate_3_largest(predictions_array, 8)
94+
prediction = "{} {:2.0f}% \n".format(self.class_names[i], 100 * first)
95+
if second > 0.1:
96+
prediction = prediction + "{} {:2.0f}% \n".format(self.class_names[j], 100 * second)
97+
if third > 0.1:
98+
prediction = prediction + "{} {:2.0f}% \n".format(self.class_names[k], 100 * third)
99+
plt.xlabel("Predicted: {} Ground Truth: {}".format(prediction, ground), color=color)
100+
101+
def calculate_3_largest(self, arr, arr_size):
102+
if arr_size < 3:
103+
print(" Invalid Input ")
104+
return
105+
106+
third = first = second = -sys.maxsize
107+
index_1 = 0
108+
index_2 = 0
109+
index_3 = 0
110+
111+
for i in range(0, arr_size):
112+
if arr[i] > first:
113+
third = second
114+
second = first
115+
first = arr[i]
116+
elif arr[i] > second:
117+
third = second
118+
second = arr[i]
119+
elif arr[i] > third:
120+
third = arr[i]
121+
122+
for i in range(0, arr_size):
123+
if arr[i] == first:
124+
index_1 = i
125+
for i in range(0, arr_size):
126+
if arr[i] == second and i != index_1:
127+
index_2 = i
128+
for i in range(0, arr_size):
129+
if arr[i] == third and i != index_1 and i!= index_2:
130+
index_3 = i
131+
return first, second, third, index_1, index_2, index_3
132+
133+
def plot_value_array(self, i, predictions_array, true_label):
134+
predictions_array, true_label = predictions_array[i], true_label[i]
135+
plt.grid(False)
136+
plt.xticks([])
137+
plt.yticks([])
138+
bar_plot = plt.bar(range(8), predictions_array, color="#777777")
139+
plt.xticks(range(8), ('N', 'D', 'G', 'C', 'A', 'H', 'M', 'O'))
140+
plt.ylim([0, 1])
141+
142+
for j in range(8):
143+
if true_label[j] >= 0.5:
144+
bar_plot[j].set_color('green')
145+
146+
for j in range(8):
147+
if predictions_array[j] >= 0.5 and true_label[j] < 0.5:
148+
bar_plot[j].set_color('red')
149+
150+
def bar_label(rects):
151+
for rect in rects:
152+
height = rect.get_height()
153+
value = height * 100
154+
if value > 1:
155+
plt.annotate('{:2.0f}%'.format(value),
156+
xy=(rect.get_x() + rect.get_width() / 2, height),
157+
xytext=(0, 3), # 3 points vertical offset
158+
textcoords="offset points",
159+
ha='center', va='bottom')
160+
161+
bar_label(bar_plot)
162+
163+
def ensure_test_prediction_exists(self, predictions):
164+
exists = False
165+
for j in range(8):
166+
if predictions[j] >= 0.5:
167+
exists = True
168+
return exists
169+
170+
def plot_output(self, test_predictions_baseline, y_test, x_test_drawing):
171+
mpl.rcParams["font.size"] = 7
172+
num_rows = 5
173+
num_cols = 3
174+
num_images = num_rows * num_cols
175+
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
176+
j = 0
177+
i = 0
178+
while j < num_images:
179+
if self.ensure_test_prediction_exists(test_predictions_baseline[i]):
180+
plt.subplot(num_rows, 2 * num_cols, 2 * j + 1)
181+
self.plot_image(i, test_predictions_baseline, y_test, x_test_drawing)
182+
plt.subplot(num_rows, 2 * num_cols, 2 * j + 2)
183+
self.plot_value_array(i, test_predictions_baseline, y_test)
184+
j = j + 1
185+
i = i + 1
186+
if i > 400:
187+
break
188+
189+
plt.subplots_adjust(bottom=0.08, right=0.95, top=0.94, left=0.05, wspace=0.11, hspace=0.56)
190+
plt.show()
191+
192+
def plot_output_single(self, i, test_predictions_baseline, y_test, x_test_drawing):
193+
plt.figure(figsize=(6, 3))
194+
plt.subplot(1, 2, 1)
195+
self.plot_image(i, test_predictions_baseline, y_test, x_test_drawing)
196+
plt.subplot(1, 2, 2)
197+
self.plot_value_array(i, test_predictions_baseline, y_test)
198+
plt.show()
199+
200+
def plot_confusion_matrix(self, y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
201+
"""
202+
This function prints and plots the confusion matrix.
203+
Normalization can be applied by setting `normalize=True`.
204+
"""
205+
if not title:
206+
if normalize:
207+
title = 'Normalized confusion matrix'
208+
else:
209+
title = 'Confusion matrix, without normalization'
210+
211+
# Compute confusion matrix
212+
cm = confusion_matrix(y_true.argmax(axis=1), y_pred.argmax(axis=1))
213+
# Only use the labels that appear in the data
214+
if normalize:
215+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
216+
print("Normalized confusion matrix")
217+
else:
218+
print('Confusion matrix, without normalization')
219+
220+
print(cm)
221+
222+
fig, ax = plt.subplots()
223+
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
224+
ax.figure.colorbar(im, ax=ax)
225+
# We want to show all ticks...
226+
ax.set(xticks=np.arange(cm.shape[1]),
227+
yticks=np.arange(cm.shape[0]),
228+
# ... and label them with the respective list entries
229+
# xticklabels=classes, yticklabels=classes,
230+
title=title,
231+
ylabel='True label',
232+
xlabel='Predicted label')
233+
ax.set_ylim(8.0, -1.0)
234+
# Rotate the tick labels and set their alignment.
235+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
236+
rotation_mode="anchor")
237+
238+
# Loop over data dimensions and create text annotations.
239+
fmt = '.2f' if normalize else 'd'
240+
thresh = cm.max() / 2.
241+
for i in range(cm.shape[0]):
242+
for j in range(cm.shape[1]):
243+
ax.text(j, i, format(cm[i, j], fmt),
244+
ha="center", va="center",
245+
color="white" if cm[i, j] > thresh else "black")
246+
fig.tight_layout()
247+
return ax
248+
249+
def print_normalized_confusion_matrix(self, y_test, test_predictions_baseline):
250+
np.set_printoptions(precision=2)
251+
252+
# Plot non-normalized confusion matrix
253+
self.plot_confusion_matrix(y_test, test_predictions_baseline, classes=self.class_names,
254+
title='Confusion matrix, without normalization')
255+
256+
# Plot normalized confusion matrix
257+
self.plot_confusion_matrix(y_test, test_predictions_baseline, classes=self.class_names, normalize=True,
258+
title='Normalized confusion matrix')
259+
260+
plt.show()
261+
262+
def plot_confusion_matrix_generic(self, labels2, predictions, test_run, p=0.5):
263+
cm = confusion_matrix(labels2.argmax(axis=1), predictions.argmax(axis=1))
264+
plt.figure(figsize=(6, 6))
265+
ax = sns.heatmap(cm, annot=True, fmt="d")
266+
ax.set_ylim(8.0, -1.0)
267+
plt.title('Confusion matrix')
268+
plt.ylabel('Actual label')
269+
plt.xlabel('Predicted label')
270+
plt.savefig('image_run3' + test_run + '.png')
271+
plt.show()
272+
plt.close()

0 commit comments

Comments
 (0)