Skip to content

Commit 41e2dc5

Browse files
author
Divam Gupta
committed
readme and models added
1 parent dc10b72 commit 41e2dc5

File tree

11 files changed

+650
-96
lines changed

11 files changed

+650
-96
lines changed

LoadBatches.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import itertools
66

77

8-
def getImageArr( path , width , height , imgNorm="sub_mean" ):
8+
def getImageArr( path , width , height , imgNorm="sub_mean" , odering='channels_first' ):
99

1010
try:
1111
img = cv2.imread(path, 1)
@@ -23,12 +23,14 @@ def getImageArr( path , width , height , imgNorm="sub_mean" ):
2323
img = img.astype(np.float32)
2424
img = img/255.0
2525

26-
img = np.rollaxis(img, 2, 0)
26+
if odering == 'channels_first':
27+
img = np.rollaxis(img, 2, 0)
2728
return img
2829
except Exception, e:
2930
print path , e
3031
img = np.zeros(( height , width , 3 ))
31-
img = np.rollaxis(img, 2, 0)
32+
if odering == 'channels_first':
33+
img = np.rollaxis(img, 2, 0)
3234
return img
3335

3436

Models/FCN.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

Models/FCN32.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/models/fcn32s.py
3+
# assert 0 == 1 # fc weights into the 1x1 convs , get_upsampling_weight
4+
5+
6+
7+
from keras.models import *
8+
from keras.layers import *
9+
10+
11+
import os
12+
file_path = os.path.dirname( os.path.abspath(__file__) )
13+
14+
VGG_Weights_path = file_path+"/../../data/vgg16_weights_th_dim_ordering_th_kernels.h5"
15+
16+
17+
# for input(360,480) output will be ( 170 , 240)
18+
19+
# input_image_size -> ( height , width )
20+
21+
22+
def FCN32( nClasses , input_height=416, input_width=608 , vgg_level=3):
23+
24+
assert input_height%32 == 0
25+
assert input_width%32 == 0
26+
27+
# https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5
28+
n_classes = 3
29+
img_input = Input(shape=(3,input_height,input_width))
30+
31+
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', data_format='channels_first' )(img_input)
32+
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', data_format='channels_first' )(x)
33+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', data_format='channels_first' )(x)
34+
f1 = x
35+
# Block 2
36+
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', data_format='channels_first' )(x)
37+
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', data_format='channels_first' )(x)
38+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool', data_format='channels_first' )(x)
39+
f2 = x
40+
41+
# Block 3
42+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', data_format='channels_first' )(x)
43+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', data_format='channels_first' )(x)
44+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', data_format='channels_first' )(x)
45+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool', data_format='channels_first' )(x)
46+
f3 = x
47+
48+
# Block 4
49+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', data_format='channels_first' )(x)
50+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', data_format='channels_first' )(x)
51+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', data_format='channels_first' )(x)
52+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool', data_format='channels_first' )(x)
53+
f4 = x
54+
55+
# Block 5
56+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', data_format='channels_first' )(x)
57+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', data_format='channels_first' )(x)
58+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', data_format='channels_first' )(x)
59+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool', data_format='channels_first' )(x)
60+
f5 = x
61+
62+
x = Flatten(name='flatten')(x)
63+
x = Dense(4096, activation='relu', name='fc1')(x)
64+
x = Dense(4096, activation='relu', name='fc2')(x)
65+
x = Dense( 1024 , activation='softmax', name='predictions')(x)
66+
67+
vgg = Model( img_input , x )
68+
# vgg.load_weights(VGG_Weights_path)
69+
70+
o = f5
71+
72+
o = ( Conv2D( 4096 , ( 7 , 7 ) , activation='relu' , padding='same', data_format='channels_first'))(o)
73+
o = Dropout(0.5)(o)
74+
o = ( Conv2D( 4096 , ( 1 , 1 ) , activation='relu' , padding='same', data_format='channels_first'))(o)
75+
o = Dropout(0.5)(o)
76+
77+
o = ( Conv2D( nClasses , ( 1 , 1 ) ,kernel_initializer='he_normal' , data_format='channels_first'))(o)
78+
o = Conv2DTranspose( nClasses , kernel_size=(64,64) , strides=(32,32) , use_bias=False , data_format='channels_first' )(o)
79+
o_shape = Model(img_input , o ).output_shape
80+
81+
outputHeight = o_shape[2]
82+
outputWidth = o_shape[3]
83+
84+
print "koko" , o_shape
85+
86+
o = (Reshape(( -1 , outputHeight*outputWidth )))(o)
87+
# o = (Permute((2, 1)))(o)
88+
# o = (Activation('softmax'))(o)
89+
model = Model( img_input , o )
90+
model.outputWidth = outputWidth
91+
model.outputHeight = outputHeight
92+
93+
return model
94+
95+
96+
if __name__ == '__main__':
97+
m = FCN32( 101 )
98+
from keras.utils import plot_model
99+
plot_model( m , show_shapes=True , to_file='model.png')

Models/FCN8.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
2+
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/models/fcn32s.py
3+
# assert 0 == 1 # fc weights into the 1x1 convs , get_upsampling_weight
4+
5+
6+
7+
from keras.models import *
8+
from keras.layers import *
9+
10+
11+
import os
12+
file_path = os.path.dirname( os.path.abspath(__file__) )
13+
14+
VGG_Weights_path = file_path+"/../../data/vgg16_weights_th_dim_ordering_th_kernels.h5"
15+
16+
17+
# for input(360,480) output will be ( 170 , 240)
18+
19+
# input_image_size -> ( height , width )
20+
21+
# crop o1 wrt o2
22+
def crop( o1 , o2 , i ):
23+
o_shape2 = Model( i , o2 ).output_shape
24+
outputHeight2 = o_shape2[2]
25+
outputWidth2 = o_shape2[3]
26+
27+
o_shape1 = Model( i , o1 ).output_shape
28+
outputHeight1 = o_shape1[2]
29+
outputWidth1 = o_shape1[3]
30+
31+
cx = abs( outputWidth1 - outputWidth2 )
32+
cy = abs( outputHeight2 - outputHeight1 )
33+
34+
if outputWidth1 > outputWidth2:
35+
o1 = Cropping2D( cropping=((0,0) , ( 0 , cx )), data_format='channels_first' )(o1)
36+
else:
37+
o2 = Cropping2D( cropping=((0,0) , ( 0 , cx )), data_format='channels_first' )(o2)
38+
39+
if outputHeight1 > outputHeight2 :
40+
o1 = Cropping2D( cropping=((0,cy) , ( 0 , 0 )), data_format='channels_first' )(o1)
41+
else:
42+
o2 = Cropping2D( cropping=((0, cy ) , ( 0 , 0 )), data_format='channels_first' )(o2)
43+
44+
return o1 , o2
45+
46+
def FCN8( nClasses , input_height=416, input_width=608 , vgg_level=3):
47+
48+
# assert input_height%32 == 0
49+
# assert input_width%32 == 0
50+
51+
# https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5
52+
n_classes = 3
53+
img_input = Input(shape=(3,input_height,input_width))
54+
55+
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', data_format='channels_first' )(img_input)
56+
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', data_format='channels_first' )(x)
57+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', data_format='channels_first' )(x)
58+
f1 = x
59+
# Block 2
60+
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', data_format='channels_first' )(x)
61+
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', data_format='channels_first' )(x)
62+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool', data_format='channels_first' )(x)
63+
f2 = x
64+
65+
# Block 3
66+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', data_format='channels_first' )(x)
67+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', data_format='channels_first' )(x)
68+
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', data_format='channels_first' )(x)
69+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool', data_format='channels_first' )(x)
70+
f3 = x
71+
72+
# Block 4
73+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', data_format='channels_first' )(x)
74+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', data_format='channels_first' )(x)
75+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', data_format='channels_first' )(x)
76+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool', data_format='channels_first' )(x)
77+
f4 = x
78+
79+
# Block 5
80+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', data_format='channels_first' )(x)
81+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', data_format='channels_first' )(x)
82+
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', data_format='channels_first' )(x)
83+
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool', data_format='channels_first' )(x)
84+
f5 = x
85+
86+
x = Flatten(name='flatten')(x)
87+
x = Dense(4096, activation='relu', name='fc1')(x)
88+
x = Dense(4096, activation='relu', name='fc2')(x)
89+
x = Dense( 1024 , activation='softmax', name='predictions')(x)
90+
91+
vgg = Model( img_input , x )
92+
# vgg.load_weights(VGG_Weights_path)
93+
94+
o = f5
95+
96+
o = ( Conv2D( 4096 , ( 7 , 7 ) , activation='relu' , padding='same', data_format='channels_first'))(o)
97+
o = Dropout(0.5)(o)
98+
o = ( Conv2D( 4096 , ( 1 , 1 ) , activation='relu' , padding='same', data_format='channels_first'))(o)
99+
o = Dropout(0.5)(o)
100+
101+
o = ( Conv2D( nClasses , ( 1 , 1 ) ,kernel_initializer='he_normal' , data_format='channels_first'))(o)
102+
o = Conv2DTranspose( nClasses , kernel_size=(4,4) , strides=(2,2) , use_bias=False, data_format='channels_first' )(o)
103+
104+
o2 = f4
105+
o2 = ( Conv2D( nClasses , ( 1 , 1 ) ,kernel_initializer='he_normal' , data_format='channels_first'))(o2)
106+
107+
o , o2 = crop( o , o2 , img_input )
108+
109+
o = Add()([ o , o2 ])
110+
111+
o = Conv2DTranspose( nClasses , kernel_size=(4,4) , strides=(2,2) , use_bias=False, data_format='channels_first' )(o)
112+
o2 = f3
113+
o2 = ( Conv2D( nClasses , ( 1 , 1 ) ,kernel_initializer='he_normal' , data_format='channels_first'))(o2)
114+
o2 , o = crop( o2 , o , img_input )
115+
o = Add()([ o2 , o ])
116+
117+
118+
o = Conv2DTranspose( nClasses , kernel_size=(16,16) , strides=(8,8) , use_bias=False, data_format='channels_first' )(o)
119+
120+
o_shape = Model(img_input , o ).output_shape
121+
122+
outputHeight = o_shape[2]
123+
outputWidth = o_shape[3]
124+
125+
o = (Reshape(( -1 , outputHeight*outputWidth )))(o)
126+
o = (Permute((2, 1)))(o)
127+
o = (Activation('softmax'))(o)
128+
model = Model( img_input , o )
129+
model.outputWidth = outputWidth
130+
model.outputHeight = outputHeight
131+
132+
return model
133+
134+
135+
136+
if __name__ == '__main__':
137+
m = FCN8( 101 )
138+
from keras.utils import plot_model
139+
plot_model( m , show_shapes=True , to_file='model.png')

Models/Segnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
2+
# todo upgrade to keras 2.0
33

44

55
from keras.models import Sequential
@@ -22,7 +22,7 @@
2222

2323

2424

25-
def segnetModel(nClasses , optimizer=None , input_height=360, input_width=480 ):
25+
def segnet(nClasses , optimizer=None , input_height=360, input_width=480 ):
2626

2727
kernel = 3
2828
filter_size = 64

Models/Unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22

3-
3+
# todo upgrade to keras 2.0
44

55
from keras.models import Sequential
66
from keras.layers import Reshape
@@ -23,7 +23,7 @@
2323

2424

2525

26-
def unet_2d (nClasses , optimizer=None , input_width=360 , input_height=480 , nChannels=1 ):
26+
def Unet (nClasses , optimizer=None , input_width=360 , input_height=480 , nChannels=1 ):
2727

2828
inputs = Input((nChannels, input_height, input_width))
2929
conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)

0 commit comments

Comments
 (0)