@@ -36,7 +36,7 @@ class ConcreteDropout(Wrapper):
36
36
prob_init: Tuple[float, float].
37
37
Probability lower / upper bounds of dropout rate initialization.
38
38
temp: float. Temperature.
39
- Determines the speed of probability adjustments.
39
+ Determines the speed of probability (i.e. dropout rate) adjustments.
40
40
seed: Seed for random probability sampling.
41
41
42
42
# References
@@ -74,6 +74,7 @@ def _concrete_dropout(self, inputs, layer_type):
74
74
# Returns
75
75
A tensor with the same shape as inputs and dropout applied.
76
76
"""
77
+ assert layer_type in {'dense' , 'conv2d' }
77
78
eps = K .cast_to_floatx (K .epsilon ())
78
79
79
80
noise_shape = K .shape (inputs )
@@ -93,6 +94,7 @@ def _concrete_dropout(self, inputs, layer_type):
93
94
)
94
95
drop_prob = K .sigmoid (drop_prob / self .temp )
95
96
97
+ # apply dropout
96
98
random_tensor = 1. - drop_prob
97
99
retain_prob = 1. - self .p
98
100
inputs *= random_tensor
@@ -104,7 +106,7 @@ def build(self, input_shape=None):
104
106
input_shape = to_tuple (input_shape )
105
107
if len (input_shape ) == 2 : # Dense_layer
106
108
input_dim = np .prod (input_shape [- 1 ]) # we drop only last dim
107
- elif len (input_shape ) == 4 : # Conv_layer
109
+ elif len (input_shape ) == 4 : # Conv2D_layer
108
110
input_dim = (input_shape [1 ]
109
111
if K .image_data_format () == 'channels_first'
110
112
else input_shape [3 ]) # we drop only channels
@@ -129,7 +131,7 @@ def build(self, input_shape=None):
129
131
130
132
super (ConcreteDropout , self ).build (input_shape )
131
133
132
- # initialize regularizer / prior KL term
134
+ # initialize regularizer / prior KL term and add to layer-loss
133
135
weight = self .layer .kernel
134
136
kernel_regularizer = (
135
137
self .weight_regularizer
@@ -146,9 +148,7 @@ def build(self, input_shape=None):
146
148
def call (self , inputs , training = None ):
147
149
def relaxed_dropped_inputs ():
148
150
return self .layer .call (self ._concrete_dropout (inputs , (
149
- 'dense'
150
- if len (K .int_shape (inputs )) == 2
151
- else 'conv2d'
151
+ 'dense' if len (K .int_shape (inputs )) == 2 else 'conv2d'
152
152
)))
153
153
154
154
return K .in_train_phase (relaxed_dropped_inputs ,
0 commit comments