Skip to content
This repository was archived by the owner on Jun 23, 2025. It is now read-only.

Commit 4271e9c

Browse files
committed
refactored tests
1 parent 9007618 commit 4271e9c

File tree

2 files changed

+109
-19
lines changed

2 files changed

+109
-19
lines changed

keras_contrib/wrappers/cdropout.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.initializers import RandomUniform
77
from keras.layers import InputSpec
88
from keras.layers.wrappers import Wrapper
9+
from keras_contrib.utils.test_utils import to_tuple
910

1011

1112
class ConcreteDropout(Wrapper):
@@ -34,8 +35,9 @@ class ConcreteDropout(Wrapper):
3435
Also known as inverse observation noise.
3536
prob_init: Tuple[float, float].
3637
Probability lower / upper bounds of dropout rate initialization.
37-
temp: float. Temperature. Not used to be optimized.
38-
seed: Seed for random probability sampling.
38+
temp: float. Temperature.
39+
Determines the speed of probability adjustments.
40+
seed: Seed for random probability sampling.
3941
4042
# References
4143
- [Concrete Dropout](https://arxiv.org/pdf/1705.07832.pdf)
@@ -44,10 +46,10 @@ class ConcreteDropout(Wrapper):
4446
def __init__(self,
4547
layer,
4648
n_data,
47-
length_scale=2e-2,
49+
length_scale=5e-2,
4850
model_precision=1,
4951
prob_init=(0.1, 0.5),
50-
temp=0.1,
52+
temp=0.4,
5153
seed=None,
5254
**kwargs):
5355
assert 'kernel_regularizer' not in kwargs
@@ -64,7 +66,7 @@ def __init__(self,
6466

6567
def _concrete_dropout(self, inputs, layer_type):
6668
"""Applies concrete dropout.
67-
Used at training time (gradients can be propagated)
69+
Used at training time (gradients can be propagated).
6870
6971
# Arguments
7072
inputs: Input.
@@ -99,6 +101,7 @@ def _concrete_dropout(self, inputs, layer_type):
99101
return inputs
100102

101103
def build(self, input_shape=None):
104+
input_shape = to_tuple(input_shape)
102105
if len(input_shape) == 2: # Dense_layer
103106
input_dim = np.prod(input_shape[-1]) # we drop only last dim
104107
elif len(input_shape) == 4: # Conv_layer
@@ -126,7 +129,7 @@ def build(self, input_shape=None):
126129

127130
super(ConcreteDropout, self).build(input_shape)
128131

129-
# initialise regularizer / prior KL term
132+
# initialize regularizer / prior KL term
130133
weight = self.layer.kernel
131134
kernel_regularizer = (
132135
self.weight_regularizer

tests/keras_contrib/wrappers/test_cdropout.py

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,23 @@
55
from numpy.testing import assert_array_almost_equal
66
from numpy.testing import assert_approx_equal
77
from numpy.testing import assert_equal
8+
from keras import backend as K
89
from keras.layers import Input, Dense, Conv1D, Conv2D, Conv3D
910
from keras.models import Model
1011
from keras_contrib.wrappers import ConcreteDropout
1112

1213

13-
def test_cdropout():
14+
@pytest.fixture
15+
def clear_session_after_test():
16+
"""Overridden: make session cleanup manually.
17+
"""
18+
pass
19+
20+
21+
@pytest.fixture(scope='module')
22+
def dense_model():
23+
"""Initialize to be tested dense model. Executed once.
24+
"""
1425
# DATA
1526
in_dim = 20
1627
init_prop = .1
@@ -28,34 +39,74 @@ def test_cdropout():
2839
# Model, reference w/o Dropout
2940
x_ref = dense(inputs)
3041
model_ref = Model(inputs, x_ref)
31-
model_ref.compile(loss='mse', optimizer='rmsprop')
42+
model_ref.compile(loss=None, optimizer='rmsprop')
43+
44+
yield {'model': model,
45+
'model_ref': model_ref,
46+
'concrete_dropout': cd,
47+
'init_prop': init_prop,
48+
'in_dim': in_dim,
49+
'X': X}
50+
if K.backend() == 'tensorflow' or K.backend() == 'cntk':
51+
K.clear_session()
52+
53+
54+
def test_cdropout_dense_3rdweight(dense_model):
55+
"""Check about correct 3rd weight (equal to initial value)
56+
"""
57+
model = dense_model['model']
58+
init_prop = dense_model['init_prop']
3259

33-
# CHECKS
34-
# Check about correct 3rd weight (equal to initial value)
3560
W = model.get_weights()
3661
assert_array_almost_equal(W[2], [np.log(init_prop)])
3762

38-
# Check if ConcreteDropout in prediction phase is the same as no dropout
63+
64+
def test_cdropout_dense_identity(dense_model):
65+
"""Check if ConcreteDropout in prediction phase is the same as no dropout
66+
"""
67+
model = dense_model['model']
68+
model_ref = dense_model['model_ref']
69+
X = dense_model['X']
70+
3971
out = model.predict(X)
4072
out_ref = model_ref.predict(X)
4173
assert_allclose(out, out_ref, atol=1e-5)
4274

43-
# Check if ConcreteDropout has the right amount of losses deposited
75+
76+
def test_cdropout_dense_loss(dense_model):
77+
"""Check if ConcreteDropout has the right amount of losses deposited
78+
"""
79+
model = dense_model['model']
80+
4481
assert_equal(len(model.losses), 1)
4582

46-
# Check if the loss correspons the the desired value
83+
84+
def test_cdropout_dense_loss_value(dense_model):
85+
"""Check if the loss corresponds the the desired value
86+
"""
87+
model = dense_model['model']
88+
X = dense_model['X']
89+
cd = dense_model['concrete_dropout']
90+
in_dim = dense_model['in_dim']
91+
4792
def sigmoid(x):
4893
return 1. / (1. + np.exp(-x))
94+
95+
W = model.get_weights()
4996
p = np.squeeze(sigmoid(W[2]))
5097
kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
5198
dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
5299
dropout_regularizer *= cd.dropout_regularizer * in_dim
53100
loss = np.sum(kernel_regularizer + dropout_regularizer)
101+
54102
eval_loss = model.evaluate(X)
55103
assert_approx_equal(eval_loss, loss)
56104

57105

58-
def test_cdropout_conv():
106+
@pytest.fixture(scope='module')
107+
def conv2d_model():
108+
"""Initialize to be tested conv model. Executed once.
109+
"""
59110
# DATA
60111
in_dim = 20
61112
init_prop = .1
@@ -75,27 +126,63 @@ def test_cdropout_conv():
75126
model_ref = Model(inputs, x_ref)
76127
model_ref.compile(loss=None, optimizer='rmsprop')
77128

78-
# CHECKS
79-
# Check about correct 3rd weight (equal to initial value)
129+
yield {'model': model,
130+
'model_ref': model_ref,
131+
'concrete_dropout': cd,
132+
'init_prop': init_prop,
133+
'in_dim': in_dim,
134+
'X': X}
135+
if K.backend() == 'tensorflow' or K.backend() == 'cntk':
136+
K.clear_session()
137+
138+
139+
def test_cdropout_conv2d_3rdweight(conv2d_model):
140+
"""Check about correct 3rd weight (equal to initial value)
141+
"""
142+
model = conv2d_model['model']
143+
init_prop = conv2d_model['init_prop']
144+
80145
W = model.get_weights()
81146
assert_array_almost_equal(W[2], [np.log(init_prop)])
82147

83-
# Check if ConcreteDropout in prediction phase is the same as no dropout
148+
149+
def test_cdropout_conv2d_identity(conv2d_model):
150+
"""Check if ConcreteDropout in prediction phase is the same as no dropout
151+
"""
152+
model = conv2d_model['model']
153+
model_ref = conv2d_model['model_ref']
154+
X = conv2d_model['X']
155+
84156
out = model.predict(X)
85157
out_ref = model_ref.predict(X)
86158
assert_allclose(out, out_ref, atol=1e-5)
87159

88-
# Check if ConcreteDropout has the right amount of losses deposited
160+
161+
def test_cdropout_conv2d_loss(conv2d_model):
162+
"""Check if ConcreteDropout has the right amount of losses deposited
163+
"""
164+
model = conv2d_model['model']
165+
89166
assert_equal(len(model.losses), 1)
90167

91-
# Check if the loss correspons the the desired value
168+
169+
def test_cdropout_conv2d_loss_value(conv2d_model):
170+
"""Check if the loss corresponds the the desired value
171+
"""
172+
model = conv2d_model['model']
173+
X = conv2d_model['X']
174+
cd = conv2d_model['concrete_dropout']
175+
92176
def sigmoid(x):
93177
return 1. / (1. + np.exp(-x))
178+
179+
W = model.get_weights()
94180
p = np.squeeze(sigmoid(W[2]))
95181
kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
96182
dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
97183
dropout_regularizer *= cd.dropout_regularizer * 1 # only channels are dropped
98184
loss = np.sum(kernel_regularizer + dropout_regularizer)
185+
99186
eval_loss = model.evaluate(X)
100187
assert_approx_equal(eval_loss, loss)
101188

0 commit comments

Comments
 (0)