@@ -14,7 +14,7 @@ def __init__(self, sess, seed):
14
14
# Training settings.
15
15
self .learning_rate = 0.001
16
16
self .num_epochs = 200
17
- self .batch_size = 100
17
+ self .batch_size = 128
18
18
19
19
self .compiled = False
20
20
@@ -35,15 +35,15 @@ def compile(self):
35
35
self .compiled = True
36
36
37
37
# Placeholders.
38
- self .X_rgb = tf .placeholder (tf .float32 , shape = (None , 32 , 32 , 3 ), name = 'X_rgb ' )
39
- self .X_gray = tf .placeholder (tf .float32 , shape = (None , 32 , 32 , 1 ), name = 'X_gray ' )
38
+ self .X = tf .placeholder (tf .float32 , shape = (None , 32 , 32 , 1 ), name = 'X ' )
39
+ self .Y = tf .placeholder (tf .float32 , shape = (None , 32 , 32 , 3 ), name = 'Y ' )
40
40
41
41
# Model.
42
42
net = UNet (self .seed )
43
- self .out = net .forward (self .X_rgb )
43
+ self .out = net .forward (self .X )
44
44
45
45
# Loss and metrics.
46
- self .loss = tf .reduce_sum (tf .square (self .out - self .X_rgb ))
46
+ self .loss = tf .reduce_sum (tf .square (self .out - self .Y ))
47
47
48
48
# Optimizer.
49
49
self .optimizer = tf .train .AdamOptimizer (learning_rate = self .learning_rate ).minimize (self .loss )
@@ -53,7 +53,7 @@ def compile(self):
53
53
54
54
self .saver = tf .train .Saver ()
55
55
56
- def train (self , X_train , X_val ):
56
+ def train (self , X_train , Y_train , X_val , Y_val ):
57
57
58
58
if not self .compiled :
59
59
print ('Compile model first.' )
@@ -68,7 +68,7 @@ def train(self, X_train, X_val):
68
68
merged = tf .summary .merge_all ()
69
69
date = str (datetime .datetime .now ()).replace (" " , "_" )[:19 ]
70
70
train_writer = tf .summary .FileWriter ('logs/' + date + '/train' , self .sess .graph )
71
- val_writer = tf .summary .FileWriter ('logs/' + date + '/val' )
71
+ val_writer = tf .summary .FileWriter ('logs/' + date + '/val' , self . sess . graph )
72
72
train_writer .flush ()
73
73
val_writer .flush ()
74
74
@@ -81,16 +81,9 @@ def train(self, X_train, X_val):
81
81
start = b * self .batch_size
82
82
end = min (b * self .batch_size + self .batch_size , N )
83
83
batch_x = X_train [start :end ,:,:,:]
84
+ batch_y = Y_train [start :end ,:,:,:]
84
85
85
- #if b != 0:
86
- # end_t_2 = timer()
87
- # print('data load: {0}'.format(end_t_2 - start_t_2))
88
-
89
- #start_t = timer()
90
- _ , l = self .sess .run ([self .optimizer , self .loss ], feed_dict = {self .X_rgb : batch_x })
91
- #end_t = timer()
92
- #print('sess.run: {0}'.format(end_t - start_t))
93
- #start_t_2 = timer()
86
+ _ , l = self .sess .run ([self .optimizer , self .loss ], feed_dict = {self .X : batch_x ,self .Y : batch_y })
94
87
95
88
epoch_loss += l / num_batches
96
89
@@ -105,9 +98,9 @@ def train(self, X_train, X_val):
105
98
train_writer .flush ()
106
99
107
100
# Add validation loss to val log.
108
- # summary = self.sess.run(merged, feed_dict={self.X_rgb : X_val})
109
- # val_writer.add_summary(summary, epoch)
110
- # val_writer.flush()
101
+ summary = self .sess .run (merged , feed_dict = {self .X : X_val , self . Y : Y_val })
102
+ val_writer .add_summary (summary , epoch )
103
+ val_writer .flush ()
111
104
112
105
# Save model.
113
106
if self .save and epoch % self .save_interval == 0 :
@@ -136,4 +129,4 @@ def predict(self, X):
136
129
print ('Compile model first.' )
137
130
return
138
131
139
- return self .sess .run (self .out , feed_dict = {self .X_rgb : X })
132
+ return self .sess .run (self .out , feed_dict = {self .X : X })
0 commit comments