Skip to content
This repository was archived by the owner on Jun 23, 2025. It is now read-only.

Commit 088698a

Browse files
committed
further test coverage increasement
1 parent 4271e9c commit 088698a

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

tests/keras_contrib/wrappers/test_cdropout.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,33 @@ def sigmoid(x):
103103
assert_approx_equal(eval_loss, loss)
104104

105105

106-
@pytest.fixture(scope='module')
107-
def conv2d_model():
108-
"""Initialize to be tested conv model. Executed once.
106+
@pytest.fixture(scope='module', params=['channels_first', 'channels_last'])
107+
def conv2d_model(request):
108+
"""Initialize to be tested conv model. Executed once per param:
109+
The whole tests are repeated for respectively
110+
`channels_first` and `channels_last`.
109111
"""
112+
assert request.param in {'channels_last', 'channels_first'}
113+
K.set_image_data_format(request.param)
114+
110115
# DATA
111116
in_dim = 20
112117
init_prop = .1
113118
np.random.seed(1)
114-
X = np.random.randn(1, in_dim, in_dim, 1)
119+
if K.image_data_format() == 'channels_last':
120+
X = np.random.randn(1, in_dim, in_dim, 1)
121+
elif K.image_data_format() == 'channels_first':
122+
X = np.random.randn(1, 1, in_dim, in_dim)
123+
else:
124+
raise ValueError('Unknown data_format:', K.image_data_format())
115125

116126
# MODEL
117-
inputs = Input(shape=(in_dim, in_dim, 1,))
127+
if K.image_data_format() == 'channels_last':
128+
inputs = Input(shape=(in_dim, in_dim, 1,))
129+
elif K.image_data_format() == 'channels_first':
130+
inputs = Input(shape=(1, in_dim, in_dim,))
131+
else:
132+
raise ValueError('Unknown data_format:', K.image_data_format())
118133
conv2d = Conv2D(1, (3, 3))
119134
# Model, normal
120135
cd = ConcreteDropout(conv2d, in_dim, prob_init=(init_prop, init_prop))

0 commit comments

Comments
 (0)