From d90e581b71a7838b13487a28cd4c87bc9c05bd57 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 9 Jul 2022 18:51:14 -0700
Subject: [PATCH 1/2] Transparent handling of tied weights

This makes `Leaf` a mutable type so that tied weights are represented by
the same leaf instance.

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
 src/interface.jl | 70 ++++++++++++++++++++++++++++++++++--------
 test/runtests.jl | 79 +++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 128 insertions(+), 21 deletions(-)

diff --git a/src/interface.jl b/src/interface.jl
index f20c7370..168f9c8d 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -6,24 +6,27 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
 
 abstract type AbstractRule end
 
-struct Leaf{R,S}
+mutable struct Leaf{R,S}
   rule::R
   state::S
 end
 
-function setup(rule, x; seen = Base.IdSet())
+function setup(rule, x; cache = IdDict{Any,Leaf}())
   rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
   if isnumeric(x)
-    x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
-    isbits(x) || push!(seen, x)
-    return Leaf(rule, init(rule, x))
+    leaf = get(cache, x, missing)
+    ismissing(leaf) || return leaf
+    leaf = Leaf(rule, init(rule, x))
+    isbits(x) || (cache[x] = leaf)
+    return leaf
   elseif isleaf(x)
     return nothing
   else
-    return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
+    return map(xᵢ -> setup(rule, xᵢ; cache), _trainable(x))
   end
 end
 
+_add!(x, x̄) = iswriteable(x) ? (x .= x .+ x̄) : eltype(x).(x .+ x̄)
 subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
 update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
@@ -31,22 +34,63 @@ update!(::Nothing, x, x̄s...) = nothing, x
 
 update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
 function update!(ℓ::Leaf, x, x̄s...)
-  s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
-  Leaf(ℓ.rule, s′), subtract!(x, x̄′)
+  ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, map(base, x̄s)...)
+  return ℓ, subtract!(x, x̄′)
 end
 
 update!(tree, x, ::Zero, ::Zero...) = tree, x
 function update!(tree, x, x̄s...)
+  cache = IdDict{Leaf,Any}()
+  _accumulate!(cache, tree, x, x̄s...)
+  return UpdateCallback(cache, IdDict{Leaf,Any}())(tree, x, x̄s...)
+end
+
+_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing
+
+function _accumulate!(cache::AbstractDict{Leaf,Any}, ℓ::Leaf, _, x̄s...)
+  acc_x̄s = get(cache, ℓ, missing)
+  cache[ℓ] = ismissing(acc_x̄s) ? x̄s : map(_add!, acc_x̄s, x̄s)
+  return
+end
+function _accumulate!(cache::AbstractDict{Leaf,Any}, tree, x, x̄s...)
+  x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
+  x′, _ = functor(typeof(x), x)
+  foreach((stᵢ, xᵢ, x̄sᵢ...) -> _accumulate!(cache, stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
+end
+
+# slightly cleaner way of closing over update! internal state
+struct UpdateCallback
+  acc_grads::IdDict{Leaf,Any}
+  param_cache::IdDict{Leaf,Any}
+end
+
+(::UpdateCallback)(::Nothing, x, x̄s...) = nothing, x
+(::UpdateCallback)(::Nothing, x, ::Zero, ::Zero...) = nothing, x
+(::UpdateCallback)(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
+(::UpdateCallback)(tree, x, ::Zero, ::Zero...) = tree, x
+
+(cb::UpdateCallback)(ℓ::Leaf, x, x̄s...) = get!(cb.param_cache, ℓ) do
+  update!(ℓ, x, pop!(cb.acc_grads, ℓ)...)
+end
+function (cb::UpdateCallback)(tree, x, x̄s...)
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, re = functor(typeof(x), x)
-  xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-  map(first, xtree), re(map(last, xtree))
+  xtree = map(cb, tree, x′, x̄s′...)
+  return map(first, xtree), re(map(last, xtree))
 end
 
 function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = iswriteable)
-  x′ = fmap(copy, x; exclude = iswriteable)
-  update!(t′, x′, x̄s...)
+  # because we rely on Leaf identity for tied parameters, they require special treatment
+  cache = IdDict()
+  tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ
+    Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable))
+  end
+  x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable)
+  x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable)
+  return update!(tree′, x′, x̄s′...)
 end
 
 # default all rules to first order calls
diff --git a/test/runtests.jl b/test/runtests.jl
index f68a2e62..c2c6e886 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,7 +22,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
     g = ([25, 33],)
     o = Descent(0.1)
     s = Optimisers.setup(o, m)
-    
+
     s2, m2 = Optimisers.update(s, m, g)
     @test m[1] == 1:2 # not mutated
     @test Optimisers.iswriteable(m[1])
@@ -157,13 +157,76 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   end
 
   @testset "tied weights" begin
-    ok = (1.0:3.0, sin, "abc", :abc)
-    m = (α = ok, β = rand(3), γ = ok)
-    m1 = (rand(3), m, rand(3))
-    @test Optimisers.setup(AdamW(), m1) isa Tuple
-    m2 = (rand(3), m, rand(3), m, rand(3)) # illegal
-    @test_throws ArgumentError Optimisers.setup(AdamW(), m2)
-  end
+    @testset "tuples" begin
+      twice = [1,2.0]
+      mtup = (twice, (copy(twice), twice)) # (tied (not tied, tied))
+
+      # simplest rule for which opt(g1) + opt(g2) != opt(g1 + g2)
+      stup = Optimisers.setup(Momentum(0.1), mtup)
+      gtup = ([3,3], ([10,10], [7,7])) # (g1, (g1 + g2, g2))
+
+      snew, mnew = Optimisers.update(stup, mtup, gtup)
+      @test mnew[1] ≈ mnew[2][1] # gradient was accumulated
+      @test mnew[2][2] === mnew[1] # and tie is not broken
+
+      st3, mt3 = Optimisers.update(stup, mtup, ([3,3], nothing))
+      @test mt3[1] ≈ [1,2] - 0.1 * [3,3]
+      @test mt3[2][2] === mt3[1]
+
+      st4, mt4 = Optimisers.update(stup, mtup, (nothing, ([5,5], [7,7])))
+      @test mt4[1] ≈ [1,2] - 0.1 * [7,7]
+    end
+
+    @testset "named" begin
+      thrice = [3f0]
+      model = (a = (x = thrice, y = Float32[4,5,6], z = true), b = ((m = (0, 1, thrice),),), c = (x = Float32[7,8], y = thrice))
+      tree = Optimisers.setup(Momentum(0.1, 0.9), model)
+      @test model.a.x === model.b[1].m[3] == model.c.y
+
+      loss(x::Array) = sum(abs2, x)
+      loss(x::Number) = x^3
+      loss(m) = sum(2 * loss(x) for x in m)
+      gradient(loss, model)
+      _, m2 = Optimisers.update(tree, model, gradient(loss, model)...)
+      @test m2.a.x === m2.b[1].m[3] == m2.c.y
+
+      loss3(m) = sum(x isa Tuple ? 0 : 2 * loss(x) for x in m)
+      gradient(loss3, model) # truncates the b limb
+      _, m3 = Optimisers.update(tree, model, gradient(loss3, model)...)
+      @test m3.a.x === m3.b[1].m[3] == m3.c.y
+    end
+
+    @testset "transpose" begin
+      mat = [1 2 3; 4 5 6.0]
+      bidir = (m = mat, f = log, t = transpose(mat), v = [7, 8, 9.0])
+      bigrad, _ = gradient((m, x) -> sum(abs2, m.m * (m.f).(m.t*x .+ m.v)), bidir, [1, 0.1])
+      @test bigrad.t isa Matrix # not a Transpose, that's the point here
+
+      state = Optimisers.setup(Descent(0.1), bidir)
+      @test state.t.parent === state.m # successfully tied
+
+      s2, b2 = Optimisers.update(state, bidir, bigrad)
+      @test b2.t.parent === b2.m # tie restored
+      @test b2.m ≈ bidir.m - 0.1 * (bigrad.m + transpose(bigrad.t)) # grad accumulated
+
+      state = Optimisers.setup(OptimiserChain(ClipGrad(10), Descent(0.1), ClipGrad(10)), bidir)
+      s2, b2 = Optimisers.update(state, bidir, bigrad)
+      @test b2.t.parent === b2.m
+      @test b2.m ≈ bidir.m - 0.1 * clamp.((bigrad.m + transpose(bigrad.t)), -10, 10)
+
+      # Similar, but now "primary" field is the transposed one:
+      tri = (a = transpose(mat), b = mat, c = transpose(mat), d = 4.0)
+      trigrad = gradient(m -> sum(abs2, m.a * (m.b * (m.c * [0.1, 1] .+ m.d) .- m.d)), tri)[1]
+      stri = Optimisers.setup(Descent(0.1), tri)
+      s3, t3 = Optimisers.update(stri, tri, trigrad)
+      @test t3.a.parent === t3.b === t3.c.parent
+      @test t3.a ≈ tri.a - 0.1 * (trigrad.a + trigrad.b' + trigrad.c)
+
+      g4 = (a = Broadcast.broadcasted(+, mat', 1), b = nothing, c = @thunk(mat' .+ 1), d = nothing)
+      # Error: no constructors for type Any
+      @test_broken s4, t4 = Optimisers.update(stri, tri, g4)
+    end
+  end # tied weights
 end
 
 @testset verbose=true "Destructure" begin

From cfa95810a0a9529eccad7a98dabe3506f35aa77c Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 9 Jul 2022 20:16:13 -0700
Subject: [PATCH 2/2] bump Functors

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 71790c78..91029e09 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Functors = "0.2.8, 0.3"
+Functors = "0.3"
 Zygote = "0.6.40"
 julia = "1.6"
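
Usage sketch (not part of the patches above): with PATCH 1/2 applied, `setup` returns one shared mutable `Leaf` for every appearance of the same array, and `update` sums all gradients belonging to a tied parameter before the rule is applied. The model below is illustrative only; the names `twice`, `enc` and `dec` and the gradient values are made up, not taken from the diff or the test suite.

    using Optimisers

    twice = [1.0, 2.0]                  # one array, used in two places
    model = (enc = twice, dec = twice)  # hypothetical tied model

    state = Optimisers.setup(Momentum(0.1), model)
    @assert state.enc === state.dec     # both fields share a single mutable Leaf

    grads = (enc = [3.0, 3.0], dec = [7.0, 7.0])
    state, model = Optimisers.update(state, model, grads)

    @assert model.enc === model.dec     # the tie survives the update
    # Gradients were accumulated before the rule ran: Momentum stepped
    # with [10.0, 10.0], not with [3, 3] and [7, 7] separately.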