
Commit 7ab871f

RFC: add Functors-aware structural gradient (#129)
* withgradient
* drop tests on 1.3
* use _trainable_walk
* wtf
1 parent eeb0ae8 commit 7ab871f

5 files changed: +98 −11 lines


.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
@@ -20,7 +20,6 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.3'
           - '1.6' # LTS
           - '1'
           - 'nightly'

Project.toml

Lines changed: 11 additions & 7 deletions
@@ -1,33 +1,37 @@
 name = "Tracker"
 uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.20"
+version = "0.2.21"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
-Adapt = "1, 2, 3"
+Adapt = "3"
 DiffRules = "1.4"
+Functors = "0.3.0"
 ForwardDiff = "0.10"
 LogExpFunctions = "0.3"
 MacroTools = "0.5"
-NNlib = "0.7.18, 0.8" # 0.7.18 is the last version which supports Julia 1.3
-NaNMath = "0.3, 1"
-Requires = "0.5, 1.0"
-SpecialFunctions = "0.10, 1, 2"
-julia = "1.3"
+NNlib = "0.8"
+NaNMath = "1"
+Optimisers = "0.2.9"
+Requires = "1.0"
+SpecialFunctions = "1, 2"
+julia = "1.6"

 [extras]
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
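The two new dependencies play different roles: Functors supplies the structural recursion (`fmap`, `functor`), while Optimisers supplies `trainable` and `isnumeric` to decide which leaves count as parameters. Below is a minimal sketch of that division of labour; the `Scale` struct is a hypothetical example, not part of this commit.

```julia
# Hypothetical illustration, not part of Tracker.jl.
using Functors, Optimisers

struct Scale
    s          # a trainable array
    n::Int     # a size-like field that should never be differentiated
end
Functors.@functor Scale                      # both fields are structural children...
Optimisers.trainable(x::Scale) = (s = x.s,)  # ...but only `s` is declared trainable

x = Scale([1.0, 2.0], 2)
@show Functors.functor(x)[1]     # (s = [1.0, 2.0], n = 2): what `fmap` recurses into
@show Optimisers.trainable(x)    # (s = [1.0, 2.0],): what the gradient walk may track
@show Optimisers.isnumeric(x.s)  # true:  arrays of floats are parameter candidates
@show Optimisers.isnumeric(x.n)  # false: integers are ignored
```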

src/Tracker.jl

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ import Printf
 import Base: ==

 export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
-  jacobian, hessian, param, back!
+  jacobian, hessian, param, back!, withgradient

 tracker(x) = nothing

@@ -70,10 +70,10 @@ end

 include("idset.jl")
 include("params.jl")
-include("back.jl")
-include("numeric.jl")
 include("lib/real.jl")
 include("lib/array.jl")
+include("back.jl")
+include("numeric.jl")
 include("forward.jl")
 @init @require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("lib/pdmats.jl")
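Since `withgradient` is now exported, `using Tracker` alone brings it into scope. A minimal smoke test of the new export; the expected output is computed by hand, not captured from a session.

```julia
using Tracker   # `withgradient` is now exported alongside `gradient`, `jacobian`, `hessian`, ...

withgradient(x -> sum(abs2, x), [1.0, 2.0])
# expected: (val = 5.0, grad = ([2.0, 4.0],))
```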

src/back.jl

Lines changed: 68 additions & 0 deletions
@@ -178,3 +178,71 @@ function jacobian(f, x::AbstractVector)
 end

 hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
+
+using Functors: fmap, functor
+using Optimisers: _trainable, isnumeric
+
+"""
+    withgradient(f, xs...)
+
+This computes the value `f(xs...)` and the gradient with respect to `xs`.
+It differs from `gradient` in several respects:
+* It recurses into `xs` using `fmap`, and thus, like Zygote's "explicit mode", it
+  returns a tree-like gradient matching the shape of a Flux model.
+  This recursion obeys restrictions imposed by `Optimisers.trainable`, if defined.
+* Only objects satisfying `Optimisers.isnumeric` are regarded as parameters;
+  in particular, integers are ignored.
+* It returns plain arrays, not tracked ones, and uses `nothing` as a strong zero gradient, like Zygote.
+
+# Examples
+```
+julia> nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin);
+
+julia> withgradient(nt, 2) do x, p
+         sum(abs2, x.vec) ^ p
+       end
+(val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing))
+
+julia> using Flux
+
+julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false));
+
+julia> withgradient(model, rand(Float32, 2)) do m, x
+         sum(abs2, m(x))
+       end
+(val = 0.035716165f0, grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing), (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),), Float32[0.12706584, -0.08858479]))
+```
+"""
+function withgradient(f, xs...)
+  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
+  l = f(pxs...)
+  losscheck(l)
+  l isa TrackedReal || return (val = l, grad = nothing)
+  @interrupts back!(l)
+  (val = data(l), grad = rec_grad(pxs))
+end
+
+function _trainable_walk(f, x)
+  func, re = functor(x)
+  isempty(func) && return x
+  done = map(f, _trainable(x))  # recurse only into trainable fields; this is `nothing` elsewhere
+  map(func, merge(func, done)) do n, t
+    isnothing(t) ? n : t  # keep the original (untracked) value where there is no trainable entry
+  end |> re  # reconstruct the whole object
+end
+_trainable_walk(f, x::Tuple) = map(f, x)
+
+# Easier to write the recursion extracting the gradients without using fmap:
+rec_grad(x::TrackedArray) = grad(x)
+rec_grad(x::TrackedReal) = grad(x)
+rec_grad(x::AbstractArray{<:Number}) = nothing
+rec_grad(x::Number) = nothing
+
+rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
+rec_grad(::Tuple{}) = nothing
+rec_grad(::NamedTuple{(), Tuple{}}) = nothing
+function rec_grad(x::T) where {T}
+  F = fieldnames(T)
+  isempty(F) && return nothing
+  map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
+end
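To show how `withgradient`, `_trainable_walk` and `rec_grad` compose, here is a hypothetical `Affine` struct (not part of this commit) mixing trainable arrays with an integer field. The expected values are worked out by hand, so treat them as an illustration rather than captured output.

```julia
# Hypothetical illustration, not part of Tracker.jl.
using Tracker, Functors, Optimisers

struct Affine
    W           # trainable matrix
    b           # trainable vector
    n::Int      # not a parameter: integers fail `isnumeric`
end
Functors.@functor Affine
Optimisers.trainable(m::Affine) = (W = m.W, b = m.b)  # only W and b may be tracked by the walk

m = Affine([2.0 0.0; 0.0 2.0], [1.0, 1.0], 3)
x = [1.0, 1.0]

out = withgradient(model -> sum(model.W * x .+ model.b), m)
# expected: out.val == 6.0
# expected: out.grad == ((W = [1.0 1.0; 1.0 1.0], b = [1.0, 1.0], n = nothing),)
# the gradient tree mirrors the struct, with `nothing` for the non-parameter field `n`
```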

test/runtests.jl

Lines changed: 16 additions & 0 deletions
@@ -17,4 +17,20 @@ using Tracker: jacobian
   @test J ≈ A.data
 end

+using Optimisers, Functors
+struct TwoThirds a; b; c; end  # evil test from Optimisers.jl
+@eval Functors.@functor TwoThirds (a, c)
+Optimisers.trainable(x::TwoThirds) = (a = x.a,)
+
+@testset "withgradient" begin
+  nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin);
+  @test withgradient((x, p) -> sum(abs2, x.vec) ^ p, nt, 2) == (val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing))
+
+  @test withgradient(x -> sum(x.v), (v = [1, 2], w = [3.0])) == (val = 3, grad = nothing)
+
+  m = TwoThirds([1.0], [2.0], [3.0])  # only the first field should be tracked, but all three should survive
+  g = withgradient(m -> only(m.a::AbstractVector + m.b::Vector + m.c::Vector), m)
+  @test g == (val = 6.0, grad = ((a = [1.0], b = nothing, c = nothing),))
 end
+
+end # overall @testset
