
Commit 3172f13

give up on customising fmap & write the recursion, add evil tests
1 parent 522f66a commit 3172f13

File tree

2 files changed: +107 -60 lines

src/interface.jl

Lines changed: 55 additions & 52 deletions
@@ -20,16 +20,28 @@ end
 Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
 
 function setup(rule::AbstractRule, model)
-  cnt = Ref(0)
-  # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
-  tree = fmapstructure(model, exclude = isnumeric) do x
-    cnt[] += 1
-    Leaf(rule, init(rule, x))
-  end
-  cnt[] == 0 && @warn "setup found no parameters in the given model"
+  cache = IdDict()
+  tree = _setup(rule, model; cache)
+  isempty(cache) && @warn "setup found no trainable parameters in this model"
   tree
 end
 
+# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
+function _setup(rule, x; cache)
+  haskey(cache, x) && return cache[x]
+  if isnumeric(x)
+    ℓ = Leaf(rule, init(rule, x))
+    if isbits(x)
+      cache[nothing] = nothing  # just to disable the warning
+      ℓ
+    else
+      cache[x] = ℓ
+    end
+  else
+    map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
+  end
+end
+
 function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
   ioc = IOContext(io, :compact => true)
   print(ioc, "Leaf(", ℓ.rule, ", ")
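In effect, the IdDict cache means that an array appearing in several places in the model gets one shared Leaf, while isbits leaves are never tied. A rough usage sketch, not part of this commit; the model layout, the rule, and all names (w, model, tree) are invented for illustration:

using Optimisers

w = rand(3)
model = (layer1 = w, layer2 = w, bias = zeros(3))  # the same array appears twice

tree = Optimisers.setup(Optimisers.Descent(0.1), model)
tree.layer1 === tree.layer2  # true: one Leaf per distinct mutable array, via the cache
tree.layer1 === tree.bias    # false: a different array gets its own Leaf and state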
@@ -41,65 +53,56 @@ end
 ### update
 ###
 
-function update!(tree, model, grad)
+function update(tree, model, grad, higher...)
+  t′ = fmap(copy, tree; exclude = maywrite)  # walks inside Leaf
+  x′ = fmap(copy, model; exclude = maywrite)
+  update!(t′, x′, grad, higher...)
+end
+
+function update!(tree, model, grad, higher...)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
-  dict = IdDict{Leaf, Any}()
-  grads!(dict, tree, model, grad)
-  # Second walk is to update the model. The walk taken follows Leaf identity
-  newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
-    haskey(dict, ℓ) || return x  # no gradient seen, nothing to do
-    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
-    ℓ.state = s′  # to get state out of here, rely on mutability of Leaf
+  grads = IdDict{Leaf, Any}()
+  _grads!(grads, tree, model, grad, higher...)
+  # Second walk is to update the model. The params cache is indexed by (tree, x),
+  # so that identified Leafs can tie isbits parameters, but setup won't do that for you:
+  newmodel = _update!(tree, model; grads, params = IdDict())
+  tree, newmodel  # note that tree is guaranteed to be updated. Also that it's not necessarily a tree.
+end
+
+function _update!(tree, x; grads, params)
+  haskey(params, (tree,x)) && return params[(tree,x)]
+  isbits(tree) && return x  # means () is not cached, and also (((),),)
+  x′, re = functor(x)
+  x′′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)
+  params[(tree,x)] = re(x′′)
+end
+function _update!(ℓ::Leaf, x; grads, params)
+  haskey(params, (ℓ,x)) && return params[(ℓ,x)]
+  params[(ℓ,x)] = if haskey(grads, ℓ)
+    ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
     subtract!(x, x̄′)
+  else
+    x  # no gradient seen
   end
-  tree, newmodel  # note that tree is guaranteed to be updated
 end
 
 subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
-grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing
-function grads!(dict::IdDict, ℓ::Leaf, x, x̄)
-  x̄₀ = get(dict, ℓ, ZeroTangent())
-  dict[ℓ] = x̄ + x̄₀  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
+_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
+  x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
+  dict[ℓ] = map(+, x̄s, x̄s₀)  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
   nothing
 end
-grads!(dict::IdDict, t, x, ::Zero) = nothing
-function grads!(dict::IdDict, tree, x, x̄s...)
-  # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from
+_grads!(dict::IdDict, t, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, tree, x, x̄s...)
+  # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
   # functor(typeof(tree), base(x̄)), for things like Transpose
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, _ = functor(typeof(x), x)
-  foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-end
-
-function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = maywrite)  # goes inside Leaf
-  x′ = fmap(copy, x; exclude = maywrite)
-  update!(t′, x′, x̄s...)
-end
-
-# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first
-function _second_walk(f, x, y)
-  x′, _ = functor(typeof(y), x)
-  y′, re = functor(y)
-  re(map(f, x′, y′))
-end
-
-# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state.
-# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work.
-struct LeafCache <: AbstractDict{Leaf,Any}
-  dict::IdDict{Leaf,Any}
+  foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
 end
-LeafCache() = LeafCache(IdDict{Leaf,Any}())
-
-Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ)
-Base.setindex!(c::LeafCache, x, _) = nothing
-Base.in(k, c::LeafCache) = k in c.dict
-Base.haskey(c::LeafCache, k) = haskey(c.dict, k)
-Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ)
-Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i)
-Base.length(c::LeafCache) = length(c.dict)
 
 # default all rules to first order calls
 apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)
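Because apply! falls back to the first-order method, the new trailing-gradient arguments to update and update! are harmless for ordinary rules. A rough sketch of the calling convention, not part of this commit; the model, gradient values, and names are invented for illustration:

using Optimisers

model = (w = [1.0, 2.0],)
tree  = Optimisers.setup(Optimisers.Adam(), model)
grad  = (w = [0.1, 0.2],)

t1, m1 = Optimisers.update(tree, model, grad)        # usual first-order call
t2, m2 = Optimisers.update(tree, model, grad, grad)  # extra gradient is dropped by the fallback
m1.w == m2.w                                         # expected true, as the Adam test below checks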

test/runtests.jl

Lines changed: 52 additions & 8 deletions
@@ -13,19 +13,26 @@ struct TwoThirds a; b; c; end
 Functors.@functor TwoThirds (a, c)
 Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 
-struct DummyHigherOrder <: AbstractRule end
+mutable struct MutTwo; x; y; end
+Functors.@functor MutTwo
 
+struct DummyHigherOrder <: AbstractRule end
 Optimisers.init(::DummyHigherOrder, x::AbstractArray) =
   (ones(eltype(x), size(x)), zero(x))
-
 dummy_update_rule(st, p, dx, dx2) = @. p - (st[1] * dx + st[2] * dx2)
 function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2)
   a, b = state
   @.. dx = a * dx + b * dx2
-
   return (a .+ 1, b .+ 1), dx
 end
 
+struct BiRule <: Optimisers.AbstractRule end
+Optimisers.init(o::BiRule, x::AbstractArray) = nothing
+function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
+  dx == dx2 || error("expected 1st & 2nd gradients to agree")
+  return state, dx
+end
+
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin
 
@@ -220,6 +227,23 @@ end
     @test_throws MethodError Optimisers.update(sm, m)
   end
 
+  @testset "2nd order gradient" begin
+    m = (α = ([1.0], sin), γ = Float32[4,3,2])
+
+    # Special rule which requires this:
+    s = Optimisers.setup(BiRule(), m)
+    g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],)
+    s1, m1 = Optimisers.update(s, m, g, g)
+    @test m1.α[1] == [0.9]
+    @test_throws Exception Optimisers.update(s, m, g, map(x -> 2 .* x, g))
+
+    # Ordinary rule which doesn't need it:
+    s2 = Optimisers.setup(Adam(), m)
+    s3, m3 = Optimisers.update(s2, m, g)
+    s4, m4 = Optimisers.update(s2, m, g, g)
+    @test m3.γ == m4.γ
+  end
+
   @testset "broadcasting macros" begin
     x = [1.0, 2.0]; y = [3,4]; z = [5,6]
     @test (@lazy x + y * z) isa Broadcast.Broadcasted
@@ -305,22 +329,42 @@ end
     # Error: no constructors for type Any
     @test_broken s4, t4 = Optimisers.update(stri, tri, g4)
   end
 
   @testset "artificial" begin
     # Interpret shared Leaf as implying shared parameters, even if this did not arise from shared arrays.
     # No API for setting this at the moment, but can construct one by hand:
-    model = (a = [1,2.0], b = [1, 2.0], c = [1, 2.0], d = [1, 2.0])
-    honest = Optimisers.setup(Momentum(0.1), model)
-    trick = (a = honest.a, b = honest.a, c = honest.c, d = honest.d)  # makes a & b shared
+    model = (a = SA[1,2.0], b = SA[1, 2.0], c = SA[1, 2.0], d = SA[1, 2.0])
+    auto = Optimisers.setup(Momentum(0.1), model)
+    @test auto.a !== auto.b  # not tied just by value
+
+    trick = (a = auto.a, b = auto.a, c = auto.c, d = auto.d)  # makes a & b tied
 
     trick2, model2 = Optimisers.update(trick, model, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10]))
     trick3, model3 = Optimisers.update(trick2, model2, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10]))
 
     @test model3.a == model3.b == model3.d  # same as having the gradients added
     @test !(model3.a ≈ model3.c)
     @test trick3.a === trick3.b  # leaves remain shared
-    model3.a === model3.b  # in fact arrays end up shared, but this is not required
   end
+
+  @testset "mutable containers" begin
+    tmp = MutTwo([1.0], [2.0])
+    model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y))
+    state = Optimisers.setup(Momentum(), model)
+
+    @test model.a === model.b
+    @test model.a !== model.c  # fields are identified, but struct is not
+
+    @test state.a.x === state.b.x
+    @test state.a === state.b
+    @test state.a === state.c  # unavoidable, but means we can't use leaf ID alone
+
+    mgrad = (a=(x=[1], y=[10]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30]))
+    state2, model2 = Optimisers.update(state, model, mgrad)
+
+    @test model2.a === model2.b  # tie of MutTwo structs is restored
+    @test model2.a !== model2.c  # but a new tie is not created
+  end
 end
 
 @testset "higher order interface" begin
