Commit 640765a

mcabbott and ToucheSir committed
allow shared parameters, take III
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent c94abf6 commit 640765a

4 files changed, +152 -45 lines changed


src/Optimisers.jl

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
 module Optimisers
 
-using Functors: functor, fmap, isleaf
+using Functors: functor, fmap, isleaf, @functor, fmapstructure, children
 using LinearAlgebra
 
 include("interface.jl")
@@ -157,8 +157,8 @@ true
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
 
-julia> t # original state should likewise be discarded
-(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
+julia> t == t2 # original state is in fact guaranteed to be mutated
+true
 ```
 """
 update!
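
The docstring change above tightens the contract of `update!`: the state tree is now guaranteed to be mutated, so the returned tree compares equal to the one passed in. A minimal REPL-style sketch of that guarantee, with illustrative values rather than the docstring's exact numbers:

    using Optimisers

    m = (x = Float32[1, 2], y = Float32[4, 5]);
    t = Optimisers.setup(Momentum(1/30, 0.9), m);

    # y has no gradient here; `nothing` is treated as zero and that Leaf is skipped
    t2, m2 = Optimisers.update!(t, m, (x = Float32[1, 1], y = nothing));

    t == t2   # true: the input state tree was updated in place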

src/adjust.jl

Lines changed: 2 additions & 2 deletions

@@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
 adjust(::Nothing, ::Real) = nothing
 adjust(::Nothing; kw...) = nothing
 
-adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
-adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+adjust(ℓ::Leaf, eta::Real) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
+adjust(ℓ::Leaf; kw...) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
 
 
 """

src/interface.jl

Lines changed: 67 additions & 34 deletions

@@ -1,57 +1,94 @@
 
-using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
+using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
 base(dx::Tangent) = backing(canonicalize(dx))
 base(dx) = dx
 const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
 
 abstract type AbstractRule end
 
-struct Leaf{R,S}
+###
+### setup
+###
+
+mutable struct Leaf{R,S}
   rule::R
   state::S
+  frozen::Bool
 end
 
-function setup(rule, x; seen = Base.IdSet())
-  rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
-  if isnumeric(x)
-    x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
-    isbits(x) || push!(seen, x)
-    return Leaf(rule, init(rule, x))
-  elseif isleaf(x)
-    return nothing
-  else
-    return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
+@functor Leaf
+
+Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
+
+function setup(rule::AbstractRule, model)
+  cnt = Ref(0)
+  # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
+  tree = fmapstructure(model, exclude = isnumeric) do x
+    cnt[] += 1
+    Leaf(rule, init(rule, x), false)
   end
+  cnt[] == 0 && @warn "setup found no parameters in the given model"
+  tree
 end
 
-subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
-
-update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
-update!(::Nothing, x, x̄s...) = nothing, x
+function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
+  ioc = IOContext(io, :compact => true)
+  print(ioc, "Leaf(", ℓ.rule, ", ")
+  show(ioc, ℓ.state)
+  print(ioc, ", ", ℓ.frozen, ")")
+end
 
-update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
-function update!(ℓ::Leaf, x, x̄s...)
-  s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
-  Leaf(ℓ.rule, s′), subtract!(x, x̄′)
+###
+### update
+###
+
+function update!(tree, model, grad)
+  # First walk is to accumulate the gradient. This recursion visits every copy of
+  # shared leaves, but stops when branches are absent from the gradient:
+  dict = IdDict{Leaf, Any}()
+  grads!(dict, tree, model, grad)
+  # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once:
+  newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ
+    ℓ isa Leaf || error("this state does not match the model, expected a Leaf here")
+    ℓ.frozen && return x
+    haskey(dict, ℓ) || return x
+    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
+    ℓ.state = s′ # to get state out of here, rely on mutability of Leaf
+    subtract!(x, x̄′)
+  end
+  tree, newmodel # note that tree is guaranteed to be updated
 end
 
-update!(tree, x, ::Zero, ::Zero...) = tree, x
-function update!(tree, x, x̄s...)
+subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
+
+grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing
+function grads!(dict::IdDict, ℓ::Leaf, x, x̄)
+  x̄₀ = get(dict, ℓ, false)
+  dict[ℓ] = Broadcast.broadcasted(+, x̄, x̄₀)
+  nothing
+end
+grads!(dict::IdDict, t, x, ::Zero) = nothing
+function grads!(dict::IdDict, tree, x, x̄s...)
+  # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from
+  # functor(typeof(tree), base(x̄)), for things like Transpose
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
-  x′, re = functor(typeof(x), x)
-  xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-  map(first, xtree), re(map(last, xtree))
+  x′, _ = functor(typeof(x), x)
+  foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
 end
 
 function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = maywrite)
+  t′ = fmap(copy, tree; exclude = maywrite) # goes inside Leaf
   x′ = fmap(copy, x; exclude = maywrite)
   update!(t′, x′, x̄s...)
 end
 
 # default all rules to first order calls
 apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
 
+###
+### sources of truth
+###
+
 """
     isnumeric(x) -> Bool
 
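
This hunk is the heart of the change: `setup` walks the model with `fmapstructure`, so arrays that are `===` end up sharing a single `Leaf`, and `update!` first collects gradients per `Leaf` in an `IdDict` before applying the rule exactly once per parameter. A hedged usage sketch of the resulting tied-weight behaviour, mirroring the new tests in test/runtests.jl below (values are illustrative):

    using Optimisers

    twice = [1.0, 2.0]
    model = (twice, (copy(twice), twice))   # first and last entries are the same array
    tree  = Optimisers.setup(Momentum(0.1), model)
    tree[1] === tree[2][2]                  # true: tied arrays share one Leaf

    grads = ([3.0, 3.0], ([10.0, 10.0], [7.0, 7.0]))
    tree2, model2 = Optimisers.update(tree, model, grads)
    model2[1] === model2[2][2]              # true: the tie survives the update
    model2[1] ≈ model2[2][1]                # true: [3,3] and [7,7] were accumulated first
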
@@ -98,6 +135,10 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu
   map(c -> c in tr ? c : nothing, ch)
 end
 
+###
+### rule definition helpers
+###
+
 """
     @.. x = x + y
@@ -135,11 +176,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
 
 onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
 onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)
-
-function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
-  ioc = IOContext(io, :compact => true)
-  print(ioc, "Leaf(", ℓ.rule, ", ")
-  show(ioc, ℓ.state)
-  print(io, ")")
-end
-
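Note how `grads!` above accumulates lazily: each tied copy wraps its gradient with `Broadcast.broadcasted(+, ...)`, and the pending sum is only evaluated when it is finally consumed, fused into that broadcast. A standalone sketch of the idiom in plain Julia (not Optimisers code):

    acc = false                                       # broadcasts like zero, matching get(dict, ℓ, false)
    acc = Broadcast.broadcasted(+, [3.0, 3.0], acc)   # gradient from the first tied copy
    acc = Broadcast.broadcasted(+, [7.0, 7.0], acc)   # gradient from the second tied copy
    Broadcast.materialize(acc) == [10.0, 10.0]        # true: one fused sum at the end
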
test/runtests.jl

Lines changed: 80 additions & 6 deletions

@@ -35,6 +35,17 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   g4 = Tangent{typeof(m)}(g...)
   s4, m4 = Optimisers.update!(s, ([1.0, 2.0],), g4)
   @test m4[1] ≈ [1,2] .- 0.1 .* [25, 33]
+
+  o5 = Momentum(0.1)
+  s5 = Optimisers.setup(o5, m)
+
+  s6, m6 = Optimisers.update(s5, m, g)
+  @test s6[1].state ≈ [2.5, 3.3]
+  @test s5[1].state == [0, 0] # not mutated -- wrong on v0.2.9
+
+  s7, m7 = Optimisers.update!(s5, m, g)
+  @test s7[1].state === s5[1].state # same array
+  @test s7[1] === s5[1] # same Leaf
 end
 
 @testset "gradient clipping" begin
@@ -212,12 +223,75 @@
 end
 
 @testset "tied weights" begin
-  ok = (1.0:3.0, sin, "abc", :abc)
-  m = (α = ok, β = rand(3), γ = ok)
-  m1 = (rand(3), m, rand(3))
-  @test Optimisers.setup(AdamW(), m1) isa Tuple
-  m2 = (rand(3), m, rand(3), m, rand(3)) # illegal
-  @test_throws ArgumentError Optimisers.setup(AdamW(), m2)
+  @testset "tuples" begin
+    twice = [1,2.0]
+    mtup = (twice, (copy(twice), twice)) # (tied (not tied, tied))
+
+    # simplest rule for which opt(g1) + opt(g2) != opt(g1 + g2)
+    stup = Optimisers.setup(Momentum(0.1), mtup)
+    gtup = ([3,3], ([10,10], [7,7])) # (g1, (g1 + g2, g2))
+
+    snew, mnew = Optimisers.update(stup, mtup, gtup)
+    @test mnew[1] ≈ mnew[2][1] # gradient was accumulated
+    @test mnew[2][2] === mnew[1] # and tie is not broken
+
+    st3, mt3 = Optimisers.update(stup, mtup, ([3,3], nothing))
+    @test mt3[1] ≈ [1,2] - 0.1 * [3,3]
+    @test mt3[2][2] === mt3[1]
+
+    st4, mt4 = Optimisers.update(stup, mtup, (nothing, ([5,5], [7,7])))
+    @test mt4[1] ≈ [1,2] - 0.1 * [7,7]
+  end
+
+  @testset "named" begin
+    thrice = [3f0]
+    model = (a = (x = thrice, y = Float32[4,5,6], z = true), b = ((m = (0, 1, thrice),),), c = (x = Float32[7,8], y = thrice))
+    tree = Optimisers.setup(Momentum(0.1, 0.9), model)
+    @test model.a.x === model.b[1].m[3] == model.c.y
+
+    loss(x::Array) = sum(abs2, x)
+    loss(x::Number) = x^3
+    loss(m) = sum(2 * loss(x) for x in m)
+    gradient(loss, model)
+    _, m2 = Optimisers.update(tree, model, gradient(loss, model)...)
+    @test m2.a.x === m2.b[1].m[3] == m2.c.y
+
+    loss3(m) = sum(x isa Tuple ? 0 : 2 * loss(x) for x in m)
+    gradient(loss3, model) # truncates the b limb
+    _, m3 = Optimisers.update(tree, model, gradient(loss3, model)...)
+    @test m3.a.x === m3.b[1].m[3] == m3.c.y
+  end
+
+  @testset "transpose" begin
+    mat = [1 2 3; 4 5 6.0]
+    bidir = (m = mat, f = log, t = transpose(mat), v = [7, 8, 9.0])
+    bigrad, _ = gradient((m, x) -> sum(abs2, m.m * (m.f).(m.t*x .+ m.v)), bidir, [1, 0.1])
+    @test bigrad.t isa Matrix # not a Transpose, that's the point here
+
+    state = Optimisers.setup(Descent(0.1), bidir)
+    @test state.t.parent === state.m # successfully tied
+
+    s2, b2 = Optimisers.update(state, bidir, bigrad)
+    @test b2.t.parent === b2.m # tie restored
+    @test b2.m ≈ bidir.m - 0.1 * (bigrad.m + transpose(bigrad.t)) # grad accumulated
+
+    state = Optimisers.setup(OptimiserChain(ClipGrad(10), Descent(0.1), ClipGrad(10)), bidir)
+    s2, b2 = Optimisers.update(state, bidir, bigrad)
+    @test b2.t.parent === b2.m
+    @test b2.m ≈ bidir.m - 0.1 * clamp.((bigrad.m + transpose(bigrad.t)), -10, 10)
+
+    # Similar, but now "primary" field is the transposed one:
+    tri = (a = transpose(mat), b = mat, c = transpose(mat), d = 4.0)
+    trigrad = gradient(m -> sum(abs2, m.a * (m.b * (m.c * [0.1, 1] .+ m.d) .- m.d)), tri)[1]
+    stri = Optimisers.setup(Descent(0.1), tri)
+    s3, t3 = Optimisers.update(stri, tri, trigrad)
+    @test t3.a.parent === t3.b === t3.c.parent
+    @test t3.a ≈ tri.a - 0.1 * (trigrad.a + trigrad.b' + trigrad.c)
+
+    g4 = (a = Broadcast.broadcasted(+, mat', 1), b = nothing, c = @thunk(mat' .+ 1), d = nothing)
+    # Error: no constructors for type Any
+    @test_broken s4, t4 = Optimisers.update(stri, tri, g4)
+  end
 end
 
 end
