Skip to content

Commit 9fa714c

Browse files
committed
Use ArbitraryDeconvolution2D in the transition-up block
1 parent 3337cb2 commit 9fa714c

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

keras_wrapper/cnn_model.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import matplotlib as mpl
22
from keras.engine.training import Model
3-
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D, Deconvolution2D
3+
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D, Deconvolution2D, ArbitraryDeconvolution2D
44
from keras.layers import merge, Dense, Dropout, Flatten, Input, Activation, BatchNormalization
55
from keras.layers.advanced_activations import PReLU
66
from keras.models import Sequential, model_from_json
77
from keras.optimizers import SGD
88
from keras.regularizers import l2
99
from keras.utils import np_utils
1010
from keras.utils.layer_utils import print_summary
11+
from keras import backend as K
1112

1213
from keras_wrapper.dataset import Data_Batch_Generator, Homogeneous_Data_Batch_Generator
1314
from keras_wrapper.deprecated.thread_loader import ThreadDataLoader, retrieveXY
@@ -3018,15 +3019,20 @@ def add_dense_block(self, in_layer, nb_layers, k, drop, init_weights):
30183019
:param init_weights: weights initialization function
30193020
:return: output layer of the dense block
30203021
"""
3022+
if K.image_dim_ordering() == 'tf':
3023+
axis = -1
3024+
else:
3025+
axis = 1
3026+
30213027
list_outputs = []
30223028
prev_layer = in_layer
30233029
for n in range(nb_layers):
30243030
# Insert dense layer
30253031
new_layer = self.add_dense_layer(prev_layer, k, drop, init_weights)
30263032
list_outputs.append(new_layer)
30273033
# Merge with previous layer
3028-
prev_layer = merge([new_layer, prev_layer], mode='concat', concat_axis=1)
3029-
return merge(list_outputs, mode='concat', concat_axis=1)
3034+
prev_layer = merge([new_layer, prev_layer], mode='concat', concat_axis=axis)
3035+
return merge(list_outputs, mode='concat', concat_axis=axis)
30303036

30313037

30323038
def add_dense_layer(self, in_layer, k, drop, init_weights):
@@ -3080,11 +3086,16 @@ def add_transitiondown_block(self, x, skip_dim,
30803086
30813087
:return: [output layer, skip connection name]
30823088
"""
3089+
if K.image_dim_ordering() == 'tf':
3090+
axis = -1
3091+
else:
3092+
axis = 1
3093+
30833094
# Dense Block
30843095
x_dense = self.add_dense_block(x, nb_layers, growth, drop, init_weights) # (growth*nb_layers) feature maps added
30853096

30863097
## Concatenation and skip connection recovery for upsampling path
3087-
skip = merge([x, x_dense], mode='concat', concat_axis=1, name='down_skip_'+str(skip_dim))
3098+
skip = merge([x, x_dense], mode='concat', concat_axis=axis, name='down_skip_'+str(skip_dim))
30883099
#skip = x_dense
30893100

30903101
# Transition Down
@@ -3126,14 +3137,22 @@ def add_transitionup_block(self, x, skip_conn, skip_conn_shapes,
31263137
31273138
:return: output layer
31283139
"""
3140+
if K.image_dim_ordering() == 'tf':
3141+
axis = -1
3142+
else:
3143+
axis = 1
3144+
31293145
# Transition Up
3130-
x = Deconvolution2D(nb_filters_deconv, 3, 3, init=init_weights,
3131-
output_shape=tuple([None, nb_filters_deconv]+out_dim),
3132-
subsample=(2, 2), border_mode='same')(x)
3146+
x = ArbitraryDeconvolution2D(nb_filters_deconv, 3, 3, init=init_weights,
3147+
subsample=(2, 2), border_mode='same')(x)
3148+
# x = Deconvolution2D(nb_filters_deconv, 3, 3, init=init_weights,
3149+
# output_shape=tuple([None, nb_filters_deconv]+out_dim),
3150+
# subsample=(2, 2), border_mode='same')(x)
3151+
31333152
# Skip connection concatenation
31343153
if out_dim in skip_conn_shapes:
31353154
skip = skip_conn[skip_conn_shapes.index(out_dim)]
3136-
x = merge([skip, x], mode='concat', concat_axis=1, name='skip_'+str(out_dim))
3155+
x = merge([skip, x], mode='concat', concat_axis=axis, name='skip_'+str(out_dim))
31373156
# Dense Block
31383157
x = self.add_dense_block(x, nb_layers, growth, drop, init_weights) # (growth*nb_layers) feature maps added
31393158
return x

0 commit comments

Comments (0)