Commit 79edbb2: unet added
1 parent: 08f2f88

File tree: 3 files changed, +85 -3 lines

Models/Segnet.py  (5 additions, 0 deletions)

@@ -79,6 +79,11 @@ def segnetModel(nClasses , optimizer=None , input_height=360, input_width=480 ):
 	model.add(Convolution2D( nClasses , 1, 1, border_mode='valid',))
+
+	model.outputHeight = model.output_shape[-2]
+	model.outputWidth = model.output_shape[-1]
+
+
 	model.add(Reshape(( nClasses , model.output_shape[-2]*model.output_shape[-1] ), input_shape=( nClasses , model.output_shape[-2], model.output_shape[-1] )))
 
 	model.add(Permute((2, 1)))
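
The two new attributes record the spatial size of Segnet's output grid (under channels-first ordering, output_shape[-2] is height and output_shape[-1] is width) before the Reshape flattens it, so callers can build label arrays at exactly the resolution the network predicts. A minimal sketch of how a data loader might consume them; the helper name, the cv2 dependency, and the mask layout are assumptions, not part of this commit:

# Hypothetical helper, not from this repo: one-hot encode a class-id mask
# at the model's output resolution (model.outputHeight x model.outputWidth).
import numpy as np
import cv2  # assumed available

def make_label_array(seg_mask, n_classes, output_height, output_width):
    # cv2.resize takes (width, height); nearest-neighbour keeps class ids intact
    seg = cv2.resize(seg_mask, (output_width, output_height), interpolation=cv2.INTER_NEAREST)
    labels = np.zeros((output_height * output_width, n_classes))
    labels[np.arange(seg.size), seg.flatten()] = 1.0  # one row of one-hot per pixel
    return labels

This matches the (height*width, nClasses) target layout that Segnet's trailing Reshape and Permute((2, 1)) produce.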

Models/Unet.py  (new file: 69 additions, 0 deletions)

from keras.models import Sequential
from keras.layers import Reshape
from keras.models import Model
from keras.layers.core import Layer, Dense, Dropout, Activation, Flatten, Reshape, Merge, Permute
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Convolution3D, MaxPooling3D, ZeroPadding3D, UpSampling3D
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.convolutional import Convolution1D, MaxPooling1D
from keras.layers.recurrent import LSTM
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam, SGD
from keras.layers.embeddings import Embedding
from keras.utils import np_utils
from keras.regularizers import ActivityRegularizer
from keras import backend as K


def unet_2d(nClasses, optimizer=None, input_width=360, input_height=480, nChannels=1):

    # Contracting path: two conv blocks, each followed by 2x2 max pooling.
    inputs = Input((nChannels, input_height, input_width))
    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
    conv1 = Dropout(0.2)(conv1)
    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
    conv2 = Dropout(0.2)(conv2)
    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    # Bottleneck.
    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
    conv3 = Dropout(0.2)(conv3)
    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)

    # Expanding path: upsample and concatenate with the matching encoder
    # feature maps (skip connections) along the channel axis.
    up1 = merge([UpSampling2D(size=(2, 2))(conv3), conv2], mode='concat', concat_axis=1)
    conv4 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up1)
    conv4 = Dropout(0.2)(conv4)
    conv4 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv4)

    up2 = merge([UpSampling2D(size=(2, 2))(conv4), conv1], mode='concat', concat_axis=1)
    conv5 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up2)
    conv5 = Dropout(0.2)(conv5)
    conv5 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv5)

    # 1x1 convolution down to nClasses channels, then reshape/permute so the
    # softmax is applied per pixel: output shape (height*width, nClasses).
    conv6 = Convolution2D(nClasses, 1, 1, activation='relu', border_mode='same')(conv5)
    conv6 = core.Reshape((nClasses, input_height*input_width))(conv6)
    conv6 = core.Permute((2, 1))(conv6)

    conv7 = core.Activation('softmax')(conv6)

    model = Model(input=inputs, output=conv7)

    if optimizer is not None:
        model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=['accuracy'])

    return model
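
For context, the new function can be exercised directly. A minimal smoke test might look like the sketch below; it assumes Keras 1.x with Theano-style 'th' (channels-first) image ordering, which the (nChannels, input_height, input_width) Input implies, and that Models is importable as a package, as train.py suggests:

# Minimal sketch: build the new U-Net and inspect its output shape.
from Models.Unet import unet_2d

m = unet_2d(nClasses=10, optimizer='adadelta', input_width=480, input_height=360, nChannels=3)

# Input sides must be divisible by 4 so the two upsampling stages line up
# with the encoder feature maps. The final Reshape/Permute give one softmax
# vector per pixel: expected output shape (None, 360*480, 10).
print(m.output_shape)
m.summary()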

train.py  (11 additions, 3 deletions)

@@ -3,10 +3,18 @@
 import Models , LoadBatches
-G = LoadBatches.imageSegmentationGenerator( "data/clothes_seg/prepped/images_prepped_train/" , "data/clothes_seg/prepped/annotations_prepped_train/" , 1, 10 , 800 , 550 , 400 , 272 )
-G2 = LoadBatches.imageSegmentationGenerator( "data/clothes_seg/prepped/images_prepped_test/" , "data/clothes_seg/prepped/annotations_prepped_test/" , 1, 10 , 800 , 550 , 400 , 272 )
 
-m = Models.VGGSegnet.VGGSegnet( 10 , use_vgg_weights=True , optimizer='adadelta' , input_image_size=( 800 , 550 ) )
+
+m = Models.VGGSegnet.VGGSegnet( 10 , use_vgg_weights=True , optimizer='adadelta' , input_image_size=( input_height , input_width ) )
+
+output_height = m.outputHeight
+output_width = m.outputWidth
+
+
+G = LoadBatches.imageSegmentationGenerator( train_images_path , train_segs_path , train_batch_size, n_classes , input_height , input_width , output_height , output_width )
+G2 = LoadBatches.imageSegmentationGenerator( val_images_path , val_segs_path , val_batch_size, n_classes , input_height , input_width , output_height , output_width )
+
+
 m.fit_generator( G , 512 , nb_epoch=10 )
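
The rewritten calls refer to configuration variables (train_images_path, n_classes, input_height, and so on) that are not defined anywhere in this hunk; presumably they are set earlier in train.py or in a later commit. A hedged sketch of the kind of block that would have to sit above it, reusing the values carried by the deleted hard-coded lines (illustrative only, not part of the commit):

# Hypothetical configuration block; the names match the variables used in
# the new generator/model calls, and the values mirror the hard-coded
# arguments this commit removed.
train_images_path = "data/clothes_seg/prepped/images_prepped_train/"
train_segs_path   = "data/clothes_seg/prepped/annotations_prepped_train/"
val_images_path   = "data/clothes_seg/prepped/images_prepped_test/"
val_segs_path     = "data/clothes_seg/prepped/annotations_prepped_test/"
train_batch_size  = 1
val_batch_size    = 1
n_classes         = 10
input_height      = 800   # the old call passed ( 800 , 550 ); which value is height vs. width is an assumption
input_width       = 550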
