@@ -2,7 +2,9 @@ istraining() = false
 
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+
+ChainRulesCore.@non_differentiable _isactive(::Any)
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
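Note on the hunk above: the layer's `active` field is a `Union{Nothing, Bool}` (the `testmode!` code later in this diff sets it to `nothing` or a `Bool`), so wrapping the branch result in `Bool(...)` gives inference a concrete return type, and marking `_isactive` as `@non_differentiable` keeps Zygote from tracing through the mode check. A minimal sketch of the resolution logic; the `Toy` struct is hypothetical and only for illustration:

```julia
# Hypothetical stand-in for a Flux layer that carries an `active` flag.
mutable struct Toy
    active::Union{Nothing, Bool}  # nothing means "auto": defer to istraining()
end

istraining() = false
_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)

_isactive(Toy(nothing))  # false: falls back to istraining()
_isactive(Toy(true))     # true, returned as a concrete Bool
```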
@@ -29,26 +31,51 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
 
 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
+dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
+# Internal function without kwargs to keep Zygote generated code type stable
+function _dropout(rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  return _apply_mask(x, mask)
+end
+
+function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  # Required because we don't always call dropout_mask
+  MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
+  project_x = ProjectTo(x)
+  return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
+end
+
+# Also needed for type stability. Otherwise inference lifts the Union into a
+# Union{pullback{Nothing}, pullback{AbstractArray}}
+struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
+  mask::Union{Nothing,M}
+  project::P
+end
+
+function (pb::DropoutPullback)(dy)
+  dx = pb.project(_apply_mask(dy, pb.mask))
+  return (NoTangent(), NoTangent(), dx, NoTangent())
+end
+
+_apply_mask(x, ::Nothing) = x
+_apply_mask(x, mask) = x .* mask
+
+dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
+dropout_mask(rng, x::CuArray, p, dims) =
+  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
+dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
+function _dropout_mask(rng, x, p, dims)
   realfptype = float(real(eltype(x)))
   y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
 
 # TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)
 
 """
     Dropout(p; dims=:, rng = rng_from_array())
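The hunk above replaces the early `return` in `dropout` with a kwarg-free `_dropout` whose mask is either an array or `nothing`, plus a hand-written `rrule` and a concrete `DropoutPullback`, so the pullback type does not depend on which branch was taken. A rough usage sketch of the intended behaviour, assuming this branch of Flux together with Zygote (not part of the diff):

```julia
using Flux, Zygote, Random, Test

rng = Random.default_rng()
x = randn(Float32, 3, 4)

# Inactive dropout samples no mask and passes the input through unchanged.
@test Flux.dropout(rng, x, 0.5f0; active=false) == x

# Active dropout still differentiates as before; the explicit rrule is meant
# to keep this call type stable whether or not a mask was generated.
y, back = Zygote.pullback(x -> sum(Flux.dropout(rng, x, 0.5f0)), x)
dx, = back(1f0)
@test size(dx) == size(x)
```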
@@ -106,10 +133,7 @@
 @functor Dropout
 trainable(a::Dropout) = (;)
 
-function (a::Dropout)(x)
-  _isactive(a) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))
 
 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
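With the layer's forward pass collapsed into a single `_dropout` call, train/test behaviour is decided purely by the `active` argument that `_isactive` computes, and `testmode!`/`trainmode!` simply flip the flag it reads. A small usage sketch using the public Flux API (not shown in the diff):

```julia
using Flux, Test

d = Flux.Dropout(0.5)
x = ones(Float32, 4, 2)

Flux.testmode!(d)   # sets d.active = false, so _isactive(d) is false
@test d(x) == x     # dropout becomes a pass-through

Flux.trainmode!(d)  # sets d.active = true
y = d(x)            # entries are either 0 or scaled by 1/(1 - p)
```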
@@ -226,7 +250,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]
 
 @functor LayerNorm
 
-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))
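The LayerNorm call site now goes through `_normalize(x, dims, ϵ)`, presumably a positional-argument counterpart of `normalise` in the same kwarg-free spirit as `_dropout`; its definition is not shown in this hunk, and the public behaviour should be unchanged. For reference, a short usage sketch of the layer itself (standard Flux API, not part of the diff):

```julia
using Flux, Statistics

ln = Flux.LayerNorm(5)      # normalises over the first dimension of size 5
x = randn(Float32, 5, 8)
y = ln(x)

# At initialisation the affine part is identity-like, so each column of y
# has mean ≈ 0 and standard deviation ≈ 1.
mean(y; dims=1)
std(y; dims=1)
```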