Skip to content

Commit 5d58a67

Browse files
committed
added cli utils
1 parent 79edbb2 commit 5d58a67

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

LoadBatches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def imageSegmentationGenerator( images_path , segs_path , batch_size, n_classe
6666

6767
assert len( images ) == len(segmentations)
6868
for im , seg in zip(images,segmentations):
69-
assert( im.split('/')[-1] == seg.split('/')[-1] )
69+
assert( im.split('/')[-1].split(".")[0] == seg.split('/')[-1].split(".")[0] )
7070

7171
zipped = itertools.cycle( zip(images,segmentations) )
7272

train.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,66 @@
1+
import argparse
2+
import Models , LoadBatches
13

24

5+
"""
36
7+
THEANO_FLAGS=device=gpu0,floatX=float32 python train.py \
8+
--save_weights_path=weights/ex1 \
9+
--train_images="data/clothes_seg/prepped/images_prepped_train/" \
10+
--train_annotations="data/clothes_seg/prepped/annotations_prepped_train/" \
11+
--n_classes=10 \
12+
--input_height=800 \
13+
--input_width=550
414
5-
import Models , LoadBatches
615
16+
"""
17+
18+
parser = argparse.ArgumentParser()
19+
parser.add_argument("--save_weights_path", type = str )
20+
parser.add_argument("--train_images", type = str )
21+
parser.add_argument("--train_annotations", type = str )
22+
parser.add_argument("--n_classes", type=int )
23+
parser.add_argument("--input_height", type=int , default = 224 )
24+
parser.add_argument("--input_width", type=int , default = 224 )
25+
26+
parser.add_argument('--validate',action='store_false')
27+
parser.add_argument("--val_images", type = str , default = "")
28+
parser.add_argument("--val_annotations", type = str , default = "")
29+
parser.add_argument("--epochs", type = int, default = 5 )
30+
parser.add_argument("--batch_size", type = int, default = 2 )
31+
parser.add_argument("--val_batch_size", type = int, default = 2 )
32+
parser.add_argument("--load_weights", type = str , default = "")
33+
34+
args = parser.parse_args()
35+
36+
train_images_path = args.train_images
37+
train_segs_path = args.train_annotations
38+
train_batch_size = args.batch_size
39+
n_classes = args.n_classes
40+
input_height = args.input_height
41+
input_width = args.input_width
42+
validate = args.validate
43+
save_weights_path = args.save_weights_path
44+
epochs = args.epochs
45+
46+
if validate:
47+
val_images_path = args.val_images
48+
val_segs_path = args.val_annotations
49+
val_batch_size = args.val_batch_size
750

851
m = Models.VGGSegnet.VGGSegnet( 10 , use_vgg_weights=True , optimizer='adadelta' , input_image_size=( input_height , input_width ) )
952

1053
output_height = m.outputHeight
1154
output_width = m.outputWidth
1255

13-
1456
G = LoadBatches.imageSegmentationGenerator( train_images_path , train_segs_path , train_batch_size, n_classes , input_height , input_width , output_height , output_width )
15-
G2 = LoadBatches.imageSegmentationGenerator( val_images_path , val_segs_path , val_batch_size, n_classes , input_height , input_width , output_height , output_width )
1657

1758

18-
m.fit_generator( G , 512 , nb_epoch=10 )
59+
if validate:
60+
G2 = LoadBatches.imageSegmentationGenerator( val_images_path , val_segs_path , val_batch_size, n_classes , input_height , input_width , output_height , output_width )
61+
1962

63+
for ep in range( epochs ):
64+
m.fit_generator( G , 512 , nb_epoch=1 )
65+
m.save_weights( save_weights_path + "." + str( ep ) )
2066

0 commit comments

Comments
 (0)