
Commit 30c45e2

finally made cdropout for keras-contrib available

1 parent fff2642 · commit 30c45e2

File tree

3 files changed: +227 −0 lines changed

keras_contrib/wrappers/__init__.py

Lines changed: 3 additions & 0 deletions
from __future__ import absolute_import

from .cdropout import ConcreteDropout
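With this re-export in place, downstream code imports the wrapper from the package root, which is exactly the import the test file below relies on:

```python
from keras_contrib.wrappers import ConcreteDropout
```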

keras_contrib/wrappers/cdropout.py

Lines changed: 165 additions & 0 deletions
# -*- coding: utf-8 -*-

import numpy as np

from keras import backend as K
from keras.engine import InputSpec
from keras.initializers import RandomUniform
from keras.layers import Wrapper


class ConcreteDropout(Wrapper):
    """A wrapper automating the dropout rate choice
    through the 'Concrete Dropout' technique.

    # Example

    ```python
    # as first layer in a sequential model:
    model = Sequential()
    model.add(ConcreteDropout(Dense(8), n_data=5000, input_shape=(16,)))
    # now model.output_shape == (None, 8)
    # subsequent layers: no need for input shape
    model.add(ConcreteDropout(Dense(32), n_data=500))
    # now model.output_shape == (None, 32)

    # Note that the current implementation supports Conv2D layers as well.
    ```

    # Arguments
        layer: The layer to be wrapped.
        n_data: int. Length of the dataset.
        length_scale: float. Prior lengthscale.
        model_precision: float. Model precision parameter is `1` for classification.
            Also known as inverse observation noise.
        prob_init: Tuple[float, float].
            Probability lower / upper bounds of dropout rate initialization.
        temp: float. Temperature. Fixed; not optimized during training.
        seed: Seed for random probability sampling.

    # References
        - [Concrete Dropout](https://arxiv.org/pdf/1705.07832.pdf)
    """

    def __init__(self,
                 layer,
                 n_data,
                 length_scale=2e-2,
                 model_precision=1,
                 prob_init=(0.1, 0.5),
                 temp=0.1,
                 seed=None,
                 **kwargs):
        assert 'kernel_regularizer' not in kwargs
        super(ConcreteDropout, self).__init__(layer, **kwargs)
        # regularizer scales from the paper: l^2 / (tau * N) and 2 / (tau * N)
        self.weight_regularizer = length_scale**2 / (model_precision * n_data)
        self.dropout_regularizer = 2 / (model_precision * n_data)
        self.prob_init = tuple(np.log(prob_init))
        self.temp = temp
        self.seed = seed

        self.supports_masking = True
        self.p_logit = None
        self.p = None

    def _concrete_dropout(self, inputs, layer_type):
        """Applies concrete dropout.
        Used at training time (gradients can be propagated).

        # Arguments
            inputs: Input.
            layer_type: str. Either 'dense' or 'conv2d'.
        # Returns
            A tensor with the same shape as inputs and dropout applied.
        """
        eps = K.cast_to_floatx(K.epsilon())

        noise_shape = K.shape(inputs)
        if layer_type == 'conv2d':
            if K.image_data_format() == 'channels_first':
                noise_shape = (noise_shape[0], noise_shape[1], 1, 1)
            else:
                noise_shape = (noise_shape[0], 1, 1, noise_shape[3])
        unif_noise = K.random_uniform(shape=noise_shape,
                                      seed=self.seed,
                                      dtype=inputs.dtype)
        # relaxed (differentiable) Bernoulli sample via the Concrete distribution
        drop_prob = (
            K.log(self.p + eps)
            - K.log(1. - self.p + eps)
            + K.log(unif_noise + eps)
            - K.log(1. - unif_noise + eps)
        )
        drop_prob = K.sigmoid(drop_prob / self.temp)

        random_tensor = 1. - drop_prob
        retain_prob = 1. - self.p
        inputs *= random_tensor
        inputs /= retain_prob

        return inputs

    def build(self, input_shape=None):
        if len(input_shape) == 2:  # Dense layer
            input_dim = np.prod(input_shape[-1])  # we drop only the last dim
        elif len(input_shape) == 4:  # Conv2D layer
            input_dim = (input_shape[1]
                         if K.image_data_format() == 'channels_first'
                         else input_shape[3])  # we drop only channels
        else:
            raise ValueError(
                'concrete_dropout currently supports only Dense/Conv2D layers')

        self.input_spec = InputSpec(shape=input_shape)
        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        # initialise p
        self.p_logit = self.layer.add_weight(name='p_logit',
                                             shape=(1,),
                                             initializer=RandomUniform(
                                                 *self.prob_init,
                                                 seed=self.seed
                                             ),
                                             trainable=True)
        self.p = K.squeeze(K.sigmoid(self.p_logit), axis=0)

        super(ConcreteDropout, self).build(input_shape)

        # initialise regularizer / prior KL term
        weight = self.layer.kernel
        kernel_regularizer = (
            self.weight_regularizer
            * K.sum(K.square(weight))
            / (1. - self.p)
        )
        dropout_regularizer = (
            self.p * K.log(self.p)
            + (1. - self.p) * K.log(1. - self.p)
        ) * self.dropout_regularizer * input_dim
        regularizer = K.sum(kernel_regularizer + dropout_regularizer)
        self.layer.add_loss(regularizer)

    def call(self, inputs, training=None):
        def relaxed_dropped_inputs():
            return self.layer.call(self._concrete_dropout(inputs, (
                'dense'
                if len(K.int_shape(inputs)) == 2
                else 'conv2d'
            )))

        return K.in_train_phase(relaxed_dropped_inputs,
                                self.layer.call(inputs),
                                training=training)

    def get_config(self):
        config = {'weight_regularizer': self.weight_regularizer,
                  'dropout_regularizer': self.dropout_regularizer,
                  'prob_init': self.prob_init,
                  'temp': self.temp,
                  'seed': self.seed}
        base_config = super(ConcreteDropout, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return self.layer.compute_output_shape(input_shape)
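For orientation, the two scales set in `__init__` follow the referenced paper: with prior lengthscale l, model precision τ and dataset size N, `weight_regularizer = l² / (τ·N)` and `dropout_regularizer = 2 / (τ·N)`; `build` then deposits the per-layer KL term `weight_regularizer · ||W||² / (1 − p) + dropout_regularizer · K · (p·log p + (1 − p)·log(1 − p))` as a layer loss. A minimal end-to-end sketch of using the wrapper (layer sizes, data, and optimizer are illustrative assumptions, not part of the commit):

```python
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model
from keras_contrib.wrappers import ConcreteDropout

n_data = 1000                                # illustrative dataset size
X = np.random.randn(n_data, 16)
y = np.random.randn(n_data, 1)

inputs = Input(shape=(16,))
h = ConcreteDropout(Dense(32, activation='relu'), n_data=n_data)(inputs)
outputs = ConcreteDropout(Dense(1), n_data=n_data)(h)

model = Model(inputs, outputs)
model.compile(loss='mse', optimizer='adam')  # KL terms arrive via layer losses
model.fit(X, y, epochs=2, batch_size=32, verbose=0)

# read back the learned dropout rate: p_logit is the last weight of the
# wrapped layer (kernel, bias, p_logit), and p = sigmoid(p_logit)
p_logit = model.layers[1].get_weights()[-1]
print(1. / (1. + np.exp(-p_logit)))
```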
Lines changed: 59 additions & 0 deletions
import pytest
import numpy as np

from keras.layers import Input, Dense
from keras.models import Model
from numpy.testing import assert_allclose
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_approx_equal
from numpy.testing import assert_equal

from keras_contrib.wrappers import ConcreteDropout


def test_cdropout():
    # Data
    in_dim = 20
    init_prop = .1
    np.random.seed(1)
    X = np.random.randn(1, in_dim)

    # Model
    inputs = Input(shape=(in_dim,))
    dense = Dense(1, use_bias=True, input_shape=(in_dim,))
    # Model, normal
    cd = ConcreteDropout(dense, in_dim, prob_init=(init_prop, init_prop))
    x = cd(inputs)
    model = Model(inputs, x)
    model.compile(loss=None, optimizer='rmsprop')
    # Model, reference w/o dropout
    x_ref = dense(inputs)
    model_ref = Model(inputs, x_ref)
    model_ref.compile(loss='mse', optimizer='rmsprop')

    # Check for correct 3rd weight (equal to initial value)
    W = model.get_weights()
    assert_array_almost_equal(W[2], [np.log(init_prop)])

    # Check that ConcreteDropout in the prediction phase is the same as no dropout
    out = model.predict(X)
    out_ref = model_ref.predict(X)
    assert_allclose(out, out_ref, atol=1e-5)

    # Check that ConcreteDropout has the right number of losses deposited
    assert_equal(len(model.losses), 1)

    # Check that the loss corresponds to the desired value
    def sigmoid(x):
        return 1. / (1. + np.exp(-x))
    p = np.squeeze(sigmoid(W[2]))
    kernel_regularizer = cd.weight_regularizer * np.sum(np.square(W[0])) / (1. - p)
    dropout_regularizer = (p * np.log(p) + (1. - p) * np.log(1. - p))
    dropout_regularizer *= cd.dropout_regularizer * in_dim
    loss = np.sum(kernel_regularizer + dropout_regularizer)
    eval_loss = model.evaluate(X)
    assert_approx_equal(eval_loss, loss)


if __name__ == '__main__':
    pytest.main([__file__])
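One detail the test makes visible: `p_logit` is initialized with `np.log(prob_init)` rather than the logit `log(p / (1 − p))`, so the effective initial dropout rate is `sigmoid(log(p0)) = p0 / (1 + p0)`, slightly below the nominal `p0`. A quick check of that arithmetic (plain NumPy, independent of the commit):

```python
import numpy as np

p0 = 0.1                                   # nominal init_prop from the test
p_eff = 1. / (1. + np.exp(-np.log(p0)))    # sigmoid(log(p0))
assert np.isclose(p_eff, p0 / (1. + p0))   # = 0.0909..., not 0.1
```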
