This repository was archived by the owner on Jun 23, 2025. It is now read-only.

Commit 1bc09d7 (1 parent: 87e4d0f)

added layer_test & fixed invalid get_config() response

File tree

2 files changed: +51 -37 lines

keras_contrib/wrappers/cdropout.py
tests/keras_contrib/wrappers/test_cdropout.py

keras_contrib/wrappers/cdropout.py
Lines changed: 33 additions & 19 deletions
@@ -29,13 +29,16 @@ class ConcreteDropout(Wrapper):

     # Arguments
         layer: The to be wrapped layer.
-        n_data: int. Length of the dataset.
-        length_scale: float. Prior lengthscale.
-        model_precision: float. Model precision parameter is `1` for classification.
+        n_data: int. `n_data > 0`.
+            Length of the dataset.
+        length_scale: float. `length_scale > 0`.
+            Prior lengthscale.
+        model_precision: float. `model_precision > 0`.
+            Model precision parameter is `1` for classification.
             Also known as inverse observation noise.
-        prob_init: Tuple[float, float].
+        prob_init: Tuple[float, float]. `prob_init > 0`.
             Probability lower / upper bounds of dropout rate initialization.
-        temp: float. Temperature.
+        temp: float. Temperature. `temp > 0`.
             Determines the speed of probability (i.e. dropout rate) adjustments.
         seed: Seed for random probability sampling.

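
For orientation, a minimal usage sketch that satisfies the constraints documented above; the wrapped layer, input shape, and all parameter values are illustrative, not defaults taken from the source:

# Illustrative usage of the wrapper with arguments satisfying the new
# constraints (all values below are made up for the example).
from keras.layers import Input, Dense
from keras.models import Model
from keras_contrib.wrappers import ConcreteDropout

inputs = Input(shape=(8,))
outputs = ConcreteDropout(Dense(4),
                          n_data=1000,            # int, n_data > 0
                          length_scale=2e-2,      # length_scale > 0
                          model_precision=1.,     # 1 for classification
                          prob_init=(0.05, 0.1),  # 0 < lower <= upper
                          temp=0.4,               # temp > 0
                          seed=42)(inputs)
model = Model(inputs, outputs)
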
@@ -53,13 +56,23 @@ def __init__(self,
                  seed=None,
                  **kwargs):
         assert 'kernel_regularizer' not in kwargs
+        assert n_data > 0 and isinstance(n_data, int)
+        assert length_scale > 0.
+        assert prob_init[0] <= prob_init[1] and prob_init[0] > 0.
+        assert temp > 0.
+        assert model_precision > 0.
         super(ConcreteDropout, self).__init__(layer, **kwargs)
-        self.weight_regularizer = length_scale**2 / (model_precision * n_data)
-        self.dropout_regularizer = 2 / (model_precision * n_data)
-        self.prob_init = tuple(np.log(prob_init))
-        self.temp = temp
-        self.seed = seed

+        self._n_data = n_data
+        self._length_scale = length_scale
+        self._model_precision = model_precision
+        self._prob_init = prob_init
+        self._temp = temp
+        self._seed = seed
+
+        eps = K.epsilon()
+        self.weight_regularizer = length_scale**2 / (model_precision * n_data + eps)
+        self.dropout_regularizer = 2 / (model_precision * n_data + eps)
         self.supports_masking = True
         self.p_logit = None
         self.p = None
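
The two constants above implement the Concrete Dropout regularizer weighting (weight term proportional to length_scale**2 / (model_precision * n_data), dropout term proportional to 2 / (model_precision * n_data)), now guarded against a zero denominator with K.epsilon(). A quick numeric sanity check with illustrative values:

# Numeric sanity check of the constants computed above (values illustrative;
# eps stands in for K.epsilon()).
length_scale, model_precision, n_data = 2e-2, 1.0, 1000
eps = 1e-7

weight_regularizer = length_scale**2 / (model_precision * n_data + eps)
dropout_regularizer = 2 / (model_precision * n_data + eps)

print(weight_regularizer)   # ~4e-7
print(dropout_regularizer)  # ~2e-3
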
@@ -84,15 +97,15 @@ def _concrete_dropout(self, inputs, layer_type):
         else:
             noise_shape = (noise_shape[0], 1, 1, noise_shape[3])
         unif_noise = K.random_uniform(shape=noise_shape,
-                                      seed=self.seed,
+                                      seed=self._seed,
                                       dtype=inputs.dtype)
         drop_prob = (
             K.log(self.p + eps)
             - K.log(1. - self.p + eps)
             + K.log(unif_noise + eps)
             - K.log(1. - unif_noise + eps)
         )
-        drop_prob = K.sigmoid(drop_prob / self.temp)
+        drop_prob = K.sigmoid(drop_prob / self._temp)

         # apply dropout
         random_tensor = 1. - drop_prob
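
drop_prob above is the Concrete relaxation of a Bernoulli dropout mask: uniform noise pushed through a temperature-scaled sigmoid of log-odds. A NumPy sketch of the same computation, with p, temp, and the tensor shape chosen purely for illustration:

# NumPy re-derivation of the soft mask above; p, temp, and the shape are
# illustrative, and eps mirrors K.epsilon().
import numpy as np

eps, p, temp = 1e-7, 0.1, 0.4
u = np.random.uniform(size=(2, 3))               # unif_noise
logit = (np.log(p + eps) - np.log(1. - p + eps)
         + np.log(u + eps) - np.log(1. - u + eps))
drop_prob = 1. / (1. + np.exp(-logit / temp))    # sigmoid(logit / temp)
random_tensor = 1. - drop_prob                   # soft keep-mask; for small
                                                 # temp, entries sit near 0 or 1
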
@@ -123,8 +136,8 @@ def build(self, input_shape=None):
         self.p_logit = self.layer.add_weight(name='p_logit',
                                              shape=(1,),
                                              initializer=RandomUniform(
-                                                 *self.prob_init,
-                                                 seed=self.seed
+                                                 *np.log(self._prob_init),
+                                                 seed=self._seed
                                              ),
                                              trainable=True)
         self.p = K.squeeze(K.sigmoid(self.p_logit), axis=0)
@@ -156,11 +169,12 @@ def relaxed_dropped_inputs():
                                 training=training)

     def get_config(self):
-        config = {'weight_regularizer': self.weight_regularizer,
-                  'dropout_regularizer': self.dropout_regularizer,
-                  'prob_init': tuple(np.round(self.prob_init, 8)),
-                  'temp': self.temp,
-                  'seed': self.seed}
+        config = {'n_data': self._n_data,
+                  'length_scale': self._length_scale,
+                  'model_precision': self._model_precision,
+                  'prob_init': self._prob_init,
+                  'temp': self._temp,
+                  'seed': self._seed}
         base_config = super(ConcreteDropout, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
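
This hunk is the "fixed invalid get_config() response" from the commit message: the old config reported derived quantities (weight_regularizer, dropout_regularizer, log-space prob_init) that do not match the __init__ signature, so a layer could not be rebuilt from its own config. Returning the stored constructor arguments restores serialization. A round-trip sketch, assuming the standard Keras 2.x Wrapper.from_config behavior (which rebuilds the wrapped layer from the base config):

# Round-trip sketch enabled by the fix; n_data and seed are illustrative.
from keras.layers import Dense
from keras_contrib.wrappers import ConcreteDropout

original = ConcreteDropout(Dense(4), n_data=1000, seed=42)
restored = ConcreteDropout.from_config(original.get_config())
assert restored.get_config() == original.get_config()
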

tests/keras_contrib/wrappers/test_cdropout.py
Lines changed: 18 additions & 18 deletions
@@ -8,6 +8,7 @@
 from keras import backend as K
 from keras.layers import Input, Dense, Conv1D, Conv2D, Conv3D
 from keras.models import Model
+from keras_contrib.utils.test_utils import layer_test
 from keras_contrib.wrappers import ConcreteDropout


@@ -202,30 +203,29 @@ def sigmoid(x):
     assert_approx_equal(eval_loss, loss)


-def test_cdropout_1d_layer():
+@pytest.mark.parametrize('n_data', [1, 60])
+@pytest.mark.parametrize('layer, shape', [(Conv1D(8, 3), (None, 20, 1)),
+                                          (Conv3D(16, 7), (1, 20, 20, 20, 1))])
+def test_cdropout_invalid_layer(layer, shape, n_data):
     """To be replaced with a real function test, if implemented.
     """
-    in_dim = 20
-    init_prop = .1
-
     with pytest.raises(ValueError):
-        inputs = Input(shape=(in_dim, 1,))
-        ConcreteDropout(Conv1D(1, 3),
-                        in_dim,
-                        prob_init=(init_prop, init_prop))(inputs)
+        layer_test(ConcreteDropout,
+                   kwargs={'layer': layer,
+                           'n_data': n_data},
+                   input_shape=shape)


-def test_cdropout_3d_layer():
-    """To be replaced with a real function test, if implemented.
+@pytest.mark.parametrize('n_data', [1, 60])
+@pytest.mark.parametrize('layer, shape', [(Conv2D(8, 3), (None, 12, 12, 3)),
+                                          (Conv2D(16, 7), (1, 12, 12, 3))])
+def test_cdropout_valid_layer(layer, shape, n_data):
+    """Original layer test with valid parameters.
     """
-    in_dim = 20
-    init_prop = .1
-
-    with pytest.raises(ValueError):
-        inputs = Input(shape=(in_dim, in_dim, in_dim, 1,))
-        ConcreteDropout(Conv3D(1, 3),
-                        in_dim,
-                        prob_init=(init_prop, init_prop))(inputs)
+    layer_test(ConcreteDropout,
+               kwargs={'layer': layer,
+                       'n_data': n_data},
+               input_shape=shape)


 if __name__ == '__main__':
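
For reference, the removed test bodies show what makes the Conv1D/Conv3D cases "invalid": as the pre-existing tests expect, wrapping those layers raises ValueError, since the wrapper handles only dense and 2D-convolutional noise shapes at this point. A minimal stand-alone version of that check (shapes and filter counts illustrative):

# Hand-rolled version of the invalid-layer check, mirroring the removed
# test body; the filter count, kernel size, and shapes are illustrative.
import pytest
from keras.layers import Input, Conv1D
from keras_contrib.wrappers import ConcreteDropout

inputs = Input(shape=(20, 1))
with pytest.raises(ValueError):
    ConcreteDropout(Conv1D(8, 3), n_data=60)(inputs)
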
