@@ -103,18 +103,33 @@ def sigmoid(x):
103
103
assert_approx_equal (eval_loss , loss )
104
104
105
105
106
- @pytest .fixture (scope = 'module' )
107
- def conv2d_model ():
108
- """Initialize to be tested conv model. Executed once.
106
+ @pytest .fixture (scope = 'module' , params = ['channels_first' , 'channels_last' ])
107
+ def conv2d_model (request ):
108
+ """Initialize to be tested conv model. Executed once per param:
109
+ The whole tests are repeated for respectively
110
+ `channels_first` and `channels_last`.
109
111
"""
112
+ assert request .param in {'channels_last' , 'channels_first' }
113
+ K .set_image_data_format (request .param )
114
+
110
115
# DATA
111
116
in_dim = 20
112
117
init_prop = .1
113
118
np .random .seed (1 )
114
- X = np .random .randn (1 , in_dim , in_dim , 1 )
119
+ if K .image_data_format () == 'channels_last' :
120
+ X = np .random .randn (1 , in_dim , in_dim , 1 )
121
+ elif K .image_data_format () == 'channels_first' :
122
+ X = np .random .randn (1 , 1 , in_dim , in_dim )
123
+ else :
124
+ raise ValueError ('Unknown data_format:' , K .image_data_format ())
115
125
116
126
# MODEL
117
- inputs = Input (shape = (in_dim , in_dim , 1 ,))
127
+ if K .image_data_format () == 'channels_last' :
128
+ inputs = Input (shape = (in_dim , in_dim , 1 ,))
129
+ elif K .image_data_format () == 'channels_first' :
130
+ inputs = Input (shape = (1 , in_dim , in_dim ,))
131
+ else :
132
+ raise ValueError ('Unknown data_format:' , K .image_data_format ())
118
133
conv2d = Conv2D (1 , (3 , 3 ))
119
134
# Model, normal
120
135
cd = ConcreteDropout (conv2d , in_dim , prob_init = (init_prop , init_prop ))
0 commit comments