|
| 1 | +import argparse |
| 2 | +import Models , LoadBatches |
1 | 3 |
|
2 | 4 |
|
| 5 | +""" |
3 | 6 |
|
| 7 | +THEANO_FLAGS=device=gpu0,floatX=float32 python train.py \ |
| 8 | + --save_weights_path=weights/ex1 \ |
| 9 | + --train_images="data/clothes_seg/prepped/images_prepped_train/" \ |
| 10 | + --train_annotations="data/clothes_seg/prepped/annotations_prepped_train/" \ |
| 11 | + --n_classes=10 \ |
| 12 | + --input_height=800 \ |
| 13 | + --input_width=550 |
4 | 14 |
|
5 |
| -import Models , LoadBatches |
6 | 15 |
|
| 16 | +""" |
| 17 | + |
# Command-line interface.  Paths point at directories of matching image /
# per-pixel-annotation files; height/width are the network input resolution.
parser = argparse.ArgumentParser()
parser.add_argument("--save_weights_path", type=str)
parser.add_argument("--train_images", type=str)
parser.add_argument("--train_annotations", type=str)
parser.add_argument("--n_classes", type=int)
parser.add_argument("--input_height", type=int, default=224)
parser.add_argument("--input_width", type=int, default=224)

# Bug fix: this was action='store_false', which made --validate *disable*
# validation and made the default (True) try to build a validation generator
# from the empty-string default paths.  'store_true' matches the flag's name:
# validation runs only when --validate is passed.
parser.add_argument('--validate', action='store_true')
parser.add_argument("--val_images", type=str, default="")
parser.add_argument("--val_annotations", type=str, default="")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--val_batch_size", type=int, default=2)
parser.add_argument("--load_weights", type=str, default="")
args = parser.parse_args()

# Bind the parsed options to the module-level names the rest of the script
# reads, grouped by what they configure.

# Core training configuration.
n_classes = args.n_classes
epochs = args.epochs
validate = args.validate
save_weights_path = args.save_weights_path

# Training data location and batching.
train_images_path = args.train_images
train_segs_path = args.train_annotations
train_batch_size = args.batch_size

# Network input resolution.
input_height = args.input_height
input_width = args.input_width

if validate:
    # Validation inputs are only consulted when validation is enabled.
    val_images_path = args.val_images
    val_segs_path = args.val_annotations
    val_batch_size = args.val_batch_size
7 | 50 |
|
# Build the VGG-Segnet model.  Bug fix: the class count was hard-coded to 10
# instead of using the parsed --n_classes (the usage example happens to pass
# --n_classes=10, which masked the mismatch).
m = Models.VGGSegnet.VGGSegnet(n_classes, use_vgg_weights=True, optimizer='adadelta', input_image_size=(input_height, input_width))

# --load_weights was parsed but never used; resume from a checkpoint when given.
if len(args.load_weights) > 0:
    m.load_weights(args.load_weights)

# The model exposes its output spatial size (the encoder downsamples, so it
# differs from the input size); the batch generators need it to shape labels.
output_height = m.outputHeight
output_width = m.outputWidth
|
12 | 55 |
|
13 |
| - |
# Generator yielding (image batch, per-pixel label batch) pairs for training.
G = LoadBatches.imageSegmentationGenerator(train_images_path, train_segs_path, train_batch_size, n_classes, input_height, input_width, output_height, output_width)

if validate:
    G2 = LoadBatches.imageSegmentationGenerator(val_images_path, val_segs_path, val_batch_size, n_classes, input_height, input_width, output_height, output_width)

# Train one epoch at a time so a weights checkpoint is written after every
# epoch (save_weights_path is used as a filename prefix).
for ep in range(epochs):
    if validate:
        # Bug fix: G2 was created but never handed to fit_generator, so
        # validation never actually ran.  nb_val_samples is the number of
        # validation samples drawn per epoch (Keras 1 fit_generator API).
        m.fit_generator(G, 512, validation_data=G2, nb_val_samples=200, nb_epoch=1)
    else:
        m.fit_generator(G, 512, nb_epoch=1)
    m.save_weights(save_weights_path + "." + str(ep))
|
0 commit comments