|
1 | 1 | import matplotlib as mpl
|
2 | 2 | from keras.engine.training import Model
|
3 |
| -from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D, Deconvolution2D |
| 3 | +from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D, Deconvolution2D, ArbitraryDeconvolution2D |
4 | 4 | from keras.layers import merge, Dense, Dropout, Flatten, Input, Activation, BatchNormalization
|
5 | 5 | from keras.layers.advanced_activations import PReLU
|
6 | 6 | from keras.models import Sequential, model_from_json
|
7 | 7 | from keras.optimizers import SGD
|
8 | 8 | from keras.regularizers import l2
|
9 | 9 | from keras.utils import np_utils
|
10 | 10 | from keras.utils.layer_utils import print_summary
|
| 11 | +from keras import backend as K |
11 | 12 |
|
12 | 13 | from keras_wrapper.dataset import Data_Batch_Generator, Homogeneous_Data_Batch_Generator
|
13 | 14 | from keras_wrapper.deprecated.thread_loader import ThreadDataLoader, retrieveXY
|
@@ -3018,15 +3019,20 @@ def add_dense_block(self, in_layer, nb_layers, k, drop, init_weights):
|
3018 | 3019 | :param init_weights: weights initialization function
|
3019 | 3020 | :return: output layer of the dense block
|
3020 | 3021 | """
|
| 3022 | + if K.image_dim_ordering() == 'tf': |
| 3023 | + axis = -1 |
| 3024 | + else: |
| 3025 | + axis = 1 |
| 3026 | + |
3021 | 3027 | list_outputs = []
|
3022 | 3028 | prev_layer = in_layer
|
3023 | 3029 | for n in range(nb_layers):
|
3024 | 3030 | # Insert dense layer
|
3025 | 3031 | new_layer = self.add_dense_layer(prev_layer, k, drop, init_weights)
|
3026 | 3032 | list_outputs.append(new_layer)
|
3027 | 3033 | # Merge with previous layer
|
3028 |
| - prev_layer = merge([new_layer, prev_layer], mode='concat', concat_axis=1) |
3029 |
| - return merge(list_outputs, mode='concat', concat_axis=1) |
| 3034 | + prev_layer = merge([new_layer, prev_layer], mode='concat', concat_axis=axis) |
| 3035 | + return merge(list_outputs, mode='concat', concat_axis=axis) |
3030 | 3036 |
|
3031 | 3037 |
|
3032 | 3038 | def add_dense_layer(self, in_layer, k, drop, init_weights):
|
@@ -3080,11 +3086,16 @@ def add_transitiondown_block(self, x, skip_dim,
|
3080 | 3086 |
|
3081 | 3087 | :return: [output layer, skip connection name]
|
3082 | 3088 | """
|
| 3089 | + if K.image_dim_ordering() == 'tf': |
| 3090 | + axis = -1 |
| 3091 | + else: |
| 3092 | + axis = 1 |
| 3093 | + |
3083 | 3094 | # Dense Block
|
3084 | 3095 | x_dense = self.add_dense_block(x, nb_layers, growth, drop, init_weights) # (growth*nb_layers) feature maps added
|
3085 | 3096 |
|
3086 | 3097 | ## Concatenation and skip connection recovery for upsampling path
|
3087 |
| - skip = merge([x, x_dense], mode='concat', concat_axis=1, name='down_skip_'+str(skip_dim)) |
| 3098 | + skip = merge([x, x_dense], mode='concat', concat_axis=axis, name='down_skip_'+str(skip_dim)) |
3088 | 3099 | #skip = x_dense
|
3089 | 3100 |
|
3090 | 3101 | # Transition Down
|
@@ -3126,14 +3137,22 @@ def add_transitionup_block(self, x, skip_conn, skip_conn_shapes,
|
3126 | 3137 |
|
3127 | 3138 | :return: output layer
|
3128 | 3139 | """
|
| 3140 | + if K.image_dim_ordering() == 'tf': |
| 3141 | + axis = -1 |
| 3142 | + else: |
| 3143 | + axis = 1 |
| 3144 | + |
3129 | 3145 | # Transition Up
|
3130 |
| - x = Deconvolution2D(nb_filters_deconv, 3, 3, init=init_weights, |
3131 |
| - output_shape=tuple([None, nb_filters_deconv]+out_dim), |
3132 |
| - subsample=(2, 2), border_mode='same')(x) |
| 3146 | + x = ArbitraryDeconvolution2D(nb_filters_deconv, 3, 3, init=init_weights, |
| 3147 | + subsample=(2, 2), border_mode='same')(x) |
| 3148 | + # x = Deconvolution2D(nb_filters_deconv, 3, 3, init=init_weights, |
| 3149 | + # output_shape=tuple([None, nb_filters_deconv]+out_dim), |
| 3150 | + # subsample=(2, 2), border_mode='same')(x) |
| 3151 | + |
3133 | 3152 | # Skip connection concatenation
|
3134 | 3153 | if out_dim in skip_conn_shapes:
|
3135 | 3154 | skip = skip_conn[skip_conn_shapes.index(out_dim)]
|
3136 |
| - x = merge([skip, x], mode='concat', concat_axis=1, name='skip_'+str(out_dim)) |
| 3155 | + x = merge([skip, x], mode='concat', concat_axis=axis, name='skip_'+str(out_dim)) |
3137 | 3156 | # Dense Block
|
3138 | 3157 | x = self.add_dense_block(x, nb_layers, growth, drop, init_weights) # (growth*nb_layers) feature maps added
|
3139 | 3158 | return x
|
|
0 commit comments