Commit 5f63dbf

Add within_gradient (#434)
* add within_gradient
* add ForwardDiff method
* docs
* use in softmax too
1 parent 2bae421 commit 5f63dbf
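In short, `within_gradient` lets library code ask whether it is currently being differentiated, so the same function can take a different path during training. A minimal usage sketch (the `scale` helper below is purely illustrative, not part of this commit):

```julia
using NNlib, Zygote

NNlib.within_gradient([1.0, 2.0])   # false: the generic fallback, outside any gradient call

# Illustrative helper, not defined anywhere in NNlib: doubles its input only while
# a gradient is being taken, thanks to the rrule this commit adds.
scale(x) = NNlib.within_gradient(x) ? 2x : x

scale(3.0)                   # 3.0 outside differentiation
Zygote.gradient(scale, 3.0)  # (2.0,): inside the call, within_gradient returns true
```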

7 files changed: +69 -5 lines changed

Project.toml
Lines changed: 2 additions & 1 deletion

@@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,4 +32,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]
+test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]

docs/src/reference.md
Lines changed: 1 addition & 0 deletions

@@ -132,4 +132,5 @@ ctc_loss
 ```@docs
 logsumexp
 NNlib.glu
+NNlib.within_gradient
 ```

src/NNlib.jl
Lines changed: 6 additions & 0 deletions

@@ -87,6 +87,12 @@ export upsample_nearest, ∇upsample_nearest,
 include("gather.jl")
 include("scatter.jl")
 include("utils.jl")
+@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+    using .ForwardDiff
+    within_gradient(x::ForwardDiff.Dual) = true
+    within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
+end
+
 include("sampling.jl")
 include("functions.jl")
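The `@init @require` block above relies on Requires.jl, so the two `ForwardDiff.Dual` methods are only defined once ForwardDiff is loaded; until then only the generic `within_gradient(x) = false` fallback from src/utils.jl applies. A rough sketch of the resulting behaviour (assuming both packages are installed):

```julia
using NNlib

NNlib.within_gradient([1.0])   # false: only the generic fallback is defined so far

using ForwardDiff              # loading ForwardDiff triggers the @require block above

d = ForwardDiff.Dual(1.0, 1.0)
NNlib.within_gradient(d)       # true: a Dual number signals forward-mode differentiation
NNlib.within_gradient([d])     # true: likewise for arrays of Duals
```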

src/softmax.jl
Lines changed: 1 addition & 4 deletions

@@ -69,7 +69,7 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
 end

 function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
-    dx = if within_grad()
+    dx = if within_gradient(y)
         tmp = dy .* y
         tmp .- y .* sum(tmp; dims)
     else
@@ -88,9 +88,6 @@ function rrule(::typeof(softmax), x; dims = 1)
     return y, softmax_pullback
 end

-within_grad() = false
-rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
-
 fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

 """

src/utils.jl
Lines changed: 52 additions & 0 deletions

@@ -1,3 +1,55 @@
+"""
+    within_gradient(x) --> Bool
+
+Returns `false` except when used inside a `gradient` call, when it returns `true`.
+Useful for Flux regularisation layers which behave differently during training and inference.
+
+This should work with any ChainRules-based differentiation package, in which case `x` is ignored.
+But Tracker.jl overloads `within_gradient(x::TrackedArray)`, thus for widest use you should
+pass it an array whose gradient is of interest.
+There is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them).
+
+# Examples
+```
+julia> using ForwardDiff, Zygote, NNlib
+
+julia> f_good(x) = if NNlib.within_gradient(x)
+         @show 10x
+       else
+         x
+       end;
+
+julia> Zygote.withgradient(f_good, 1.0)
+10x = 10.0
+(val = 10.0, grad = (10.0,))
+
+julia> ForwardDiff.derivative(f_good, 1.0)
+10x = Dual{ForwardDiff.Tag{typeof(f_good), Float64}}(10.0,10.0)
+10.0
+
+julia> f_bad(x, y) = if any(NNlib.within_gradient, (x, y))
+         @show x * y
+       else
+         x / y
+       end;
+
+julia> Zygote.withgradient(f_bad, 2.0, 3.0)
+(val = 0.6666666666666666, grad = (0.3333333333333333, -0.2222222222222222))
+
+julia> ForwardDiff.derivative(x -> f_bad(x, 3.0), 2.0)
+x * y = Dual{ForwardDiff.Tag{var"#9#10", Float64}}(6.0,3.0)
+3.0
+```
+
+What goes wrong in `f_bad` is that Zygote knows `any` to be non-differentiable,
+and thus completely ignores its contents. This is not a perfect mechanism,
+and the only style recommended is precisely that of `f_good` above.
+"""
+within_gradient(x) = false
+
+ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())
+
+
 """
     safe_div(x, y)
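The docstring points at Flux regularisation layers as the intended consumer. A hypothetical, much-simplified sketch of such a layer (not part of this commit, and not how Flux's real `Dropout` is written) might look like:

```julia
using NNlib

# Toy layer: drops entries only while being differentiated, identity otherwise.
struct ToyDropout
    p::Float64   # probability of zeroing an entry
end

function (d::ToyDropout)(x::AbstractArray)
    if NNlib.within_gradient(x)
        mask = rand(size(x)...) .> d.p   # keep each entry with probability 1 - p
        return x .* mask ./ (1 - d.p)    # rescale so the expected value is unchanged
    else
        return x                         # inference path: no-op
    end
end

ToyDropout(0.5)(ones(3))   # == ones(3) here, since no gradient is being taken
```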

test/runtests.jl
Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ using NNlib, Test, Statistics, Random
 using ChainRulesCore, ChainRulesTestUtils
 using Base.Broadcast: broadcasted
 import FiniteDifferences
+import ForwardDiff
 import Zygote
 using Zygote: gradient
 using StableRNGs

test/utils.jl
Lines changed: 6 additions & 0 deletions

@@ -1,3 +1,9 @@
+@testset "within_gradient" begin
+    @test NNlib.within_gradient([1.0]) === false
+    @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)
+    @test NNlib.within_gradient([ForwardDiff.Dual(1.0, 2)]) === true
+end
+
 @testset "maximum_dims" begin
     ind1 = [1,2,3,4,5,6]
     @test NNlib.maximum_dims(ind1) == (6,)
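Note on the middle test: inside `gradient`, the new `rrule` makes `within_gradient(x)` evaluate to `true`, so the closure reduces to `x -> true * x`, whose derivative is `1.0`; outside a gradient call the same expression would be `false * x == 0.0`.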
