Skip to content
This repository was archived by the owner on Nov 3, 2022. It is now read-only.

Commit 54c988a

Browse files
committed
added new test cases for coverage
1 parent b0edd69 commit 54c988a

File tree

3 files changed

+55
-7
lines changed

3 files changed

+55
-7
lines changed

keras_contrib/wrappers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
from .cdropout import ConcreteDropout
3+
from .cdropout import ConcreteDropout

keras_contrib/wrappers/cdropout.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
from __future__ import absolute_import
33

44
import numpy as np
5-
65
from keras import backend as K
7-
from keras.engine import InputSpec
86
from keras.initializers import RandomUniform
7+
from keras.layers import InputSpec
98
from keras.layers.wrappers import Wrapper
109

1110

@@ -34,7 +33,7 @@ class ConcreteDropout(Wrapper):
3433
model_precision: float. Model precision parameter is `1` for classification.
3534
Also known as inverse observation noise.
3635
prob_init: Tuple[float, float].
37-
Probability lower / upper bounds of dropout rate initialization.
36+
Probability lower / upper bounds of dropout rate initialization.
3837
temp: float. Temperature. Not used to be optimized.
3938
seed: Seed for random probability sampling.
4039
@@ -156,7 +155,7 @@ def relaxed_dropped_inputs():
156155
def get_config(self):
157156
config = {'weight_regularizer': self.weight_regularizer,
158157
'dropout_regularizer': self.dropout_regularizer,
159-
'prob_init': self.prob_init,
158+
'prob_init': tuple(np.round(self.prob_init, 8)),
160159
'temp': self.temp,
161160
'seed': self.seed}
162161
base_config = super(ConcreteDropout, self).get_config()

tests/keras_contrib/wrappers/test_cdropout.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from keras_contrib.wrappers import ConcreteDropout
1212

13-
1413
def test_cdropout():
1514
# Data
1615
in_dim = 20
@@ -20,7 +19,7 @@ def test_cdropout():
2019

2120
# Model
2221
inputs = Input(shape=(in_dim,))
23-
dense = Dense(1, use_bias=True, input_shape=(in_dim,))
22+
dense = Dense(1, use_bias=True)
2423
# Model, normal
2524
cd = ConcreteDropout(dense, in_dim, prob_init=(init_prop, init_prop))
2625
x = cd(inputs)
@@ -53,6 +52,56 @@ def sigmoid(x):
5352
loss = np.sum(kernel_regularizer + dropout_regularizer)
5453
eval_loss = model.evaluate(X)
5554
assert_approx_equal(eval_loss, loss)
55+
56+
def test_cdropout_conv():
57+
# Data
58+
in_dim = 20
59+
init_prop = .1
60+
np.random.seed(1)
61+
X = np.random.randn(1, in_dim, in_dim, 1)
62+
63+
# Model
64+
inputs = Input(shape=(in_dim, in_dim, 1,))
65+
conv2d = Conv2D(1, (3,3))
66+
# Model, normal
67+
cd = ConcreteDropout(conv2d, in_dim, prob_init=(init_prop, init_prop))
68+
x = cd(inputs)
69+
model = Model(inputs, x)
70+
model.compile(loss=None, optimizer='rmsprop')
71+
# Model, reference w/o Dropout
72+
x_ref = conv2d(inputs)
73+
model_ref = Model(inputs, x_ref)
74+
model_ref.compile(loss=None, optimizer='rmsprop')
75+
76+
# Check about correct 3rd weight (equal to initial value)
77+
W = model.get_weights()
78+
assert_array_almost_equal(W[2], [np.log(init_prop)])
79+
80+
# Check if ConcreteDropout in prediction phase is the same as no dropout
81+
out = model.predict(X)
82+
out_ref = model_ref.predict(X)
83+
assert_allclose(out, out_ref, atol=1e-5)
84+
85+
# Check if ConcreteDropout has the right amount of losses deposited
86+
assert_equal(len(model.losses), 1)
87+
88+
# Check if the loss correspons the the desired value
89+
def sigmoid(x):
90+
return 1. / (1. + np.exp(-x))
91+
p = np.squeeze(sigmoid(W[2]))
92+
kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
93+
dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
94+
dropout_regularizer *= cd.dropout_regularizer * 1 # because only channels are dropped
95+
loss = np.sum(kernel_regularizer + dropout_regularizer)
96+
eval_loss = model.evaluate(X)
97+
assert_approx_equal(eval_loss, loss)
98+
99+
def test_cdropout_wrong_layertype():
100+
"""To be replaced with a real function test, if implemented.
101+
"""
102+
with pytest.raises(ValueError):
103+
inputs = Input(shape=(in_dim, in_dim,))
104+
cd = ConcreteDropout(Conv1D(1, 3), in_dim, prob_init=(init_prop, init_prop))(inputs)
56105

57106

58107
if __name__ == '__main__':

0 commit comments

Comments
 (0)