Skip to content

Commit c9bf29f

Browse files
committed
UNET
1 parent bfcf985 commit c9bf29f

File tree

3 files changed

+53
-53
lines changed

3 files changed

+53
-53
lines changed

src/helpers.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,6 @@
33
import matplotlib.pyplot as plt
44
from skimage import color, io
55

6-
7-
def one_hot_encode(y, num_classes):
8-
"""
9-
One-hot encoding.
10-
11-
Args:
12-
- y : vector to encode.
13-
- num_classes: number of classes.
14-
"""
15-
N = y.shape[0]
16-
17-
encoded = np.zeros((N, num_classes))
18-
19-
for i, label in enumerate(y):
20-
encoded[i][label] = 1
21-
22-
return encoded
23-
24-
256
def load_batch(file_path):
267
"""
278
Loads data given path to file.
@@ -85,3 +66,8 @@ def save_lab_images(img_batch, filename="images/output_{}.png"):
8566
for i in range(lab_unscaled.shape[0]):
8667
rgb = color.lab2rgb(lab_unscaled[i])
8768
io.imsave(filename.format(i), rgb)
69+
70+
def save_gray_images(img_batch, filename="images/output_{}.png"):
71+
72+
for i in range(img_batch.shape[0]):
73+
io.imsave(filename.format(i), img_batch[i])

src/main.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,51 @@
77

88
SEED = 24
99

10-
X_train, X_val, _ = load_CIFAR(SEED)
10+
X_train, Y_train, X_val, Y_val, X_test, Y_test = load_CIFAR(SEED)
1111

1212
X_train = X_train[0:1000,:,:,:]
13-
X_val = X_val[0:50,:,:,:]
13+
Y_train = Y_train[0:1000,:,:,:]
1414

15-
print(X_train.shape)
16-
print(X_val.shape)
15+
X_val = X_val[0:100,:,:,:]
16+
Y_val = Y_val[0:100,:,:,:]
17+
18+
X_test = X_test[0:100,:,:,:]
19+
Y_test = Y_test[0:100,:,:,:]
20+
21+
print('Train:')
22+
print('X_train:', X_train.shape)
23+
print('Y_train:', Y_train.shape)
24+
25+
print('Validation:')
26+
print('X_val:', X_val.shape)
27+
print('Y_val:', Y_val.shape)
28+
29+
print('Test:')
30+
print('X_test:', X_test.shape)
31+
print('Y_test:', Y_test.shape)
32+
33+
save_gray_images(X_train[0:10,:,:,:], filename="images/train_before_gray_{}.png")
34+
save_lab_images(Y_train[0:10,:,:,:], filename="images/train_before_color_{}.png")
35+
36+
save_gray_images(X_val[0:10,:,:,:], filename="images/val_before_gray_{}.png")
37+
save_lab_images(Y_val[0:10,:,:,:], filename="images/val_before_color_{}.png")
1738

1839
np.random.seed(SEED)
1940
tf.random.set_random_seed(SEED)
2041

21-
2242
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
2343

2444
UNET = Model(sess, SEED)
25-
45+
2646
UNET.compile()
2747

28-
UNET.train(X_train, X_val)
48+
print('Training...')
49+
UNET.train(X_train, Y_train, X_val, Y_val)
2950

30-
UNET.predict(X_train)
31-
32-
"""print(MLP.evaluate(X_train, Y_train))
33-
print(MLP.evaluate(X_val, Y_val))
34-
print(MLP.evaluate(X_test, Y_test))
35-
"""
36-
51+
print('Predicting training set...')
52+
pred = UNET.predict(X_train)
53+
save_lab_images(pred[0:10,:,:,:], filename="images/after_train_{}.png")
54+
55+
print('Predicting validation set...')
56+
pred = UNET.predict(X_val)
57+
save_lab_images(pred[0:10,:,:,:], filename="images/after_val_{}.png")

src/model.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, sess, seed):
1414
# Training settings.
1515
self.learning_rate = 0.001
1616
self.num_epochs = 200
17-
self.batch_size = 100
17+
self.batch_size = 128
1818

1919
self.compiled = False
2020

@@ -35,15 +35,15 @@ def compile(self):
3535
self.compiled = True
3636

3737
# Placeholders.
38-
self.X_rgb = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='X_rgb')
39-
self.X_gray = tf.placeholder(tf.float32, shape=(None, 32, 32, 1), name='X_gray')
38+
self.X = tf.placeholder(tf.float32, shape=(None, 32, 32, 1), name='X')
39+
self.Y = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='Y')
4040

4141
# Model.
4242
net = UNet(self.seed)
43-
self.out = net.forward(self.X_rgb)
43+
self.out = net.forward(self.X)
4444

4545
# Loss and metrics.
46-
self.loss = tf.reduce_sum(tf.square(self.out - self.X_rgb))
46+
self.loss = tf.reduce_sum(tf.square(self.out - self.Y))
4747

4848
# Optimizer.
4949
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
@@ -53,7 +53,7 @@ def compile(self):
5353

5454
self.saver = tf.train.Saver()
5555

56-
def train(self, X_train, X_val):
56+
def train(self, X_train, Y_train, X_val, Y_val):
5757

5858
if not self.compiled:
5959
print('Compile model first.')
@@ -68,7 +68,7 @@ def train(self, X_train, X_val):
6868
merged = tf.summary.merge_all()
6969
date = str(datetime.datetime.now()).replace(" ", "_")[:19]
7070
train_writer = tf.summary.FileWriter('logs/' + date + '/train', self.sess.graph)
71-
val_writer = tf.summary.FileWriter('logs/' + date + '/val')
71+
val_writer = tf.summary.FileWriter('logs/' + date + '/val', self.sess.graph)
7272
train_writer.flush()
7373
val_writer.flush()
7474

@@ -81,16 +81,9 @@ def train(self, X_train, X_val):
8181
start = b * self.batch_size
8282
end = min(b * self.batch_size + self.batch_size, N)
8383
batch_x = X_train[start:end,:,:,:]
84+
batch_y = Y_train[start:end,:,:,:]
8485

85-
#if b != 0:
86-
# end_t_2 = timer()
87-
# print('data load: {0}'.format(end_t_2 - start_t_2))
88-
89-
#start_t = timer()
90-
_, l = self.sess.run([self.optimizer, self.loss], feed_dict={self.X_rgb: batch_x})
91-
#end_t = timer()
92-
#print('sess.run: {0}'.format(end_t - start_t))
93-
#start_t_2 = timer()
86+
_, l = self.sess.run([self.optimizer, self.loss], feed_dict={self.X: batch_x ,self.Y: batch_y})
9487

9588
epoch_loss += l / num_batches
9689

@@ -105,9 +98,9 @@ def train(self, X_train, X_val):
10598
train_writer.flush()
10699

107100
# Add validation loss to val log.
108-
#summary = self.sess.run(merged, feed_dict={self.X_rgb: X_val})
109-
#val_writer.add_summary(summary, epoch)
110-
#val_writer.flush()
101+
summary = self.sess.run(merged, feed_dict={self.X: X_val ,self.Y: Y_val})
102+
val_writer.add_summary(summary, epoch)
103+
val_writer.flush()
111104

112105
# Save model.
113106
if self.save and epoch % self.save_interval == 0:
@@ -136,4 +129,4 @@ def predict(self, X):
136129
print('Compile model first.')
137130
return
138131

139-
return self.sess.run(self.out, feed_dict={self.X_rgb: X})
132+
return self.sess.run(self.out, feed_dict={self.X: X})

0 commit comments

Comments
 (0)