Skip to content

Commit 9036961

Browse files
committed
a
1 parent 5c2f55a commit 9036961

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

train.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
--save_weights_path=weights/ex1 \
99
--train_images="data/clothes_seg/prepped/images_prepped_train/" \
1010
--train_annotations="data/clothes_seg/prepped/annotations_prepped_train/" \
11+
--val_images="data/clothes_seg/prepped/images_prepped_test/" \
12+
--val_annotations="data/clothes_seg/prepped/annotations_prepped_test/" \
1113
--n_classes=10 \
1214
--input_height=800 \
1315
--input_width=550
@@ -59,8 +61,13 @@
5961
if validate:
6062
G2 = LoadBatches.imageSegmentationGenerator( val_images_path , val_segs_path , val_batch_size, n_classes , input_height , input_width , output_height , output_width )
6163

64+
if not validate:
65+
for ep in range( epochs ):
66+
m.fit_generator( G , 512 , nb_epoch=1 )
67+
m.save_weights( save_weights_path + "." + str( ep ) )
68+
else:
69+
for ep in range( epochs ):
70+
m.fit_generator( G , 512 , validation_data=G2 , nb_val_samples=200 , nb_epoch=1 )
71+
m.save_weights( save_weights_path + "." + str( ep ) )
6272

63-
for ep in range( epochs ):
64-
m.fit_generator( G , 512 , nb_epoch=1 )
65-
m.save_weights( save_weights_path + "." + str( ep ) )
6673

0 commit comments

Comments
 (0)