
Commit 7997174

use NNlib.within_gradient (#2152)
1 parent aba285c commit 7997174

File tree (5 files changed: +32 −9 lines)

  Project.toml
  src/cuda/cudnn.jl
  src/deprecations.jl
  src/layers/normalise.jl
  test/layers/normalisation.jl

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.9"
+NNlib = "0.8.14"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
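
The only substantive change here is the compat bound, since NNlib.within_gradient first appears in NNlib 0.8.14. A quick REPL sanity check, as a sketch (it assumes an environment where this Project.toml resolves to NNlib ≥ 0.8.14):

using NNlib
@assert isdefined(NNlib, :within_gradient)   # added in NNlib 0.8.14
NNlib.within_gradient(rand(Float32, 3))      # false when called outside differentiation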

src/cuda/cudnn.jl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
   @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
   return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                cache=cache, alpha=1, beta=0, eps=BN.ϵ,
-               training=Flux._isactive(BN)))
+               training=Flux._isactive(BN, x)))
 end
 
 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
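
On the CUDNN path only the training keyword changes: the train/test decision is now made from both the layer and the input array, matching the CPU path. A hedged usage sketch (assumes a working CUDA setup with Flux and CUDA.jl loaded; without a GPU, `gpu` is a no-op and this CUDNN method is never hit):

using Flux, CUDA

bn = BatchNorm(3) |> gpu
x  = CUDA.rand(Float32, 3, 16)

y = bn(x)                                   # outside AD: _isactive(bn, x) is false, tracked stats are used
g = gradient(x -> sum(abs2, bn(x)), x)[1]   # inside Zygote: _isactive(bn, x) is true, batch stats are used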

src/deprecations.jl

Lines changed: 11 additions & 0 deletions
@@ -86,6 +86,17 @@ Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed.
 
 @deprecate rng_from_array() default_rng_value()
 
+function istraining()
+  Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
+  false
+end
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
+
+function _isactive(m)
+  Base.depwarn("_isactive(m) is deprecated, use _isactive(m,x)", :_isactive, force=true)
+  _isactive(m, 1:0)
+end
+
 #=
 # Valid method in Optimise, old implicit style, is:
   train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
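
The old entry points keep working but now warn and forward to the new ones: istraining() still flips to true under Zygote via its rrule, and the one-argument _isactive(m) calls the two-argument form with a dummy 1:0 range, for which within_gradient is simply false. A minimal sketch of the replacement pattern users are pointed to, assuming NNlib ≥ 0.8.14 and Zygote are available:

using NNlib, Zygote

# Branch on whether `x` is currently being differentiated, instead of on Flux.istraining().
f(x) = NNlib.within_gradient(x) ? 2 .* x : x

f([1.0, 2.0])                                # [1.0, 2.0]   -- false outside AD
Zygote.gradient(x -> sum(f(x)), [1.0, 2.0])  # ([2.0, 2.0],) -- true inside the pullback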

src/layers/normalise.jl

Lines changed: 4 additions & 7 deletions
@@ -1,8 +1,5 @@
-istraining() = false
 
-ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
-
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
@@ -107,7 +104,7 @@ end
 trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
-  _isactive(a) || return x
+  _isactive(a, x) || return x
   return dropout(a.rng, x, a.p; dims=a.dims, active=true)
 end
 
@@ -162,7 +159,7 @@ AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
 trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
-  _isactive(a) || return x
+  _isactive(a, x) || return x
   p = a.p
   iszero(p) && return x
   isone(p) && return sign.(x) .* T(0)
@@ -242,7 +239,7 @@ end
 function _norm_layer_forward(
   l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
 ) where {T, N}
-  if !_isactive(l) && l.track_stats # testmode with tracked stats
+  if !_isactive(l, x) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
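
Passing the input into _isactive is the heart of the change: Zygote is detected through the rrule behind within_gradient, while ForwardDiff, which never sees ChainRules rules, can now be recognised from x itself (the new tests below rely on this). A minimal sketch of the decision rule, mirroring the new one-liner with illustrative names (FakeLayer and isactive are not Flux internals):

using NNlib

struct FakeLayer                    # stand-in for Dropout, BatchNorm, etc.
    active::Union{Bool,Nothing}     # nothing = "automatic"; testmode!/trainmode! set a Bool
end

isactive(m::FakeLayer, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

isactive(FakeLayer(nothing), rand(Float32, 3))  # false: plain forward pass, no override
isactive(FakeLayer(true), rand(Float32, 3))     # true: the explicit override wins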

test/layers/normalisation.jl

Lines changed: 15 additions & 0 deletions
@@ -475,5 +475,20 @@ end
   # This was an error, https://github.com/FluxML/Flux.jl/issues/2122
   @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
   @test !iszero(bn.μ)
+
+  # Easy case of 2122, gradient with x
+  x5 = rand(Float32, 5, 3)
+  bn1 = BatchNorm(5, relu)
+  bn2 = BatchNorm(5, relu)
+  g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1]
+  g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5)
+  @test g1 ≈ g2
+
+  # Harder case?
+  v1, re1 = Flux.destructure(BatchNorm(5, relu));
+  g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1]
+
+  v2, re2 = Flux.destructure(BatchNorm(5, relu));
+  g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2)
 end
 