Transparent handling of tied weights #100

Closed
wants to merge 2 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.2.8, 0.3"
Functors = "0.3"
Zygote = "0.6.40"
julia = "1.6"

70 changes: 57 additions & 13 deletions src/interface.jl
@@ -6,47 +6,91 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end

struct Leaf{R,S}
mutable struct Leaf{R,S}
rule::R
state::S
end
Comment on lines +9 to 12

Member: Once this is mutable, then update!(tree, model, grad) can be guaranteed to alter the state tree in place. This opens the possibility of simplifying the interface, and never returning multiple things whose order you have to remember.
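
For readers skimming the thread, a minimal toy sketch of the simplification being discussed (hypothetical ToyLeaf and toy_update! names, not Optimisers' actual code): once the per-parameter state container is mutable, an update can write the new optimiser state back into the object the caller's tree already holds, so nothing needs to be rebound.

    # Toy illustration only: a mutable leaf lets the update write new optimiser
    # state into the same object the caller's tree already points at.
    mutable struct ToyLeaf{S}
        state::S
    end

    function toy_update!(leaf::ToyLeaf, x, dx; eta = 0.1)
        leaf.state = leaf.state .+ dx   # new state lands in the existing leaf
        return x .- eta .* dx           # updated parameters are still returned
    end

    leaf = ToyLeaf(zeros(2))
    x, dx = [1.0, 2.0], [0.5, 0.5]
    xnew = toy_update!(leaf, x, dx)
    @assert leaf.state == [0.5, 0.5]    # caller's leaf was mutated in place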

function setup(rule, x; seen = Base.IdSet())
function setup(rule, x; cache = IdDict{Any,Leaf}())
rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
if isnumeric(x)
x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
isbits(x) || push!(seen, x)
return Leaf(rule, init(rule, x))
leaf = get(cache, x, missing)
ismissing(leaf) || return leaf
leaf = Leaf(rule, init(rule, x))
isbits(x) || (cache[x] = leaf)
return leaf
elseif isleaf(x)
return nothing
else
return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
return map(xᵢ -> setup(rule, xᵢ; cache), _trainable(x))
end
end

_add!(x, x̄) = iswriteable(x) ? (x .= x .+ x̄) : eltype(x).(x .+ x̄)

Member: Not sure this is what we want. We should never ever mutate a gradient, but I think we can just call @lazy x̄old + x̄new and lazily accumulate?

Member Author: My worry with the lazy accumulation approach is threefold. First, it blows any chance of making this type stable out the window. Secondly, it's possible the lazy Broadcasted may be evaluated multiple times as it passes through a chain of rules and thus incur accumulation overhead more than once. Lastly, complicated broadcasts come with a lot of compilation latency (especially on GPU) and I'm wary of making optimizers worse than they already are on that front.

Member (@mcabbott, Aug 28, 2022): We can do it eagerly to avoid this. But we cannot mutate the gradients, as they may be shared with others (e.g. from the rule for +).

Lazy .+ is almost free; it's very difficult to picture evaluating this twice ever costing as much as a copy. Not sure about compile times.

Member Author (@ToucheSir, Aug 28, 2022): Good point about aliased gradients. If this is a correctness issue, we don't have much of a choice :)
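
To make the aliasing worry concrete, a small standalone sketch (illustrative helper names, not the PR's _add!): when one gradient array is reachable from two slots, in-place accumulation silently changes the other slot, while an eager non-mutating x̄old .+ x̄new leaves it alone.

    # The same array can back two gradient slots (tied weights, or reuse by a rule
    # such as the one for +); accumulating in place then changes both slots.
    g = [1.0, 1.0]
    grads = (a = g, b = g)                 # two slots aliasing one array

    add_inplace!(x, dx) = (x .+= dx; x)    # mutating accumulation
    add_eager(x, dx) = x .+ dx             # eager, non-mutating accumulation

    add_inplace!(grads.a, [1.0, 1.0])
    @assert grads.b == [2.0, 2.0]          # the aliased slot changed too: a bug

    h = [1.0, 1.0]
    grads2 = (a = h, b = h)
    acc = add_eager(grads2.a, [1.0, 1.0])
    @assert grads2.b == [1.0, 1.0] && acc == [2.0, 2.0]   # original untouched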

subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)

update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
update!(::Nothing, x, x̄s...) = nothing, x

update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
function update!(ℓ::Leaf, x, x̄s...)
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, map(base, x̄s)...)
return ℓ, subtract!(x, x̄′)
end

update!(tree, x, ::Zero, ::Zero...) = tree, x
function update!(tree, x, x̄s...)
cache = IdDict{Leaf,Any}()
_accumulate!(cache, tree, x, x̄s...)
return UpdateCallback(cache, IdDict{Leaf,Any}())(tree, x, x̄s...)
end

_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing
_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing
Comment on lines +48 to +51

Member Author: There are a lot of overloads with high degrees of overlap across multiple functions. I couldn't think of a way to deduplicate some of them, so if anyone has ideas that would be swell.

Member: I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106.

I also think it would be clearer to write variable names more often, not _, since 5 arguments is quite a few to count.

Member Author: I think I tried that, but ran into ambiguities. This is the smallest number of methods I could come up with that didn't have ambiguities. If you can narrow that down, that would be superb.

Member Author: The underscores are mostly to appease the linter and possibly improve latency(??). Perhaps ::Any would work better, though I'm not sure that addresses your point about clarity?

Member: I don't see how names can affect latency. I just mean they let your eye know what the 4th argument means, which ::Any doesn't help.

Member Author: My impression was it implicitly acted as a @nospecialize, but looking at https://github.com/JuliaLang/julia/blob/98e1b13a7db5aa1d05b9a48085993375cf2298d0/src/method.c#L656 that may not be the case.
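
For anyone curious what such ambiguities look like, a toy reproduction with a generic f (a hypothetical two-method setup, not the PR's actual _accumulate! methods): one method is more specific in the first argument, the other in a trailing argument, so Julia cannot pick a winner.

    # Two methods, each more specific in a different position:
    f(::Nothing, x, dxs...) = :no_state            # skip when there is no state
    f(tree, x, ::Nothing, ::Nothing...) = :no_grad # skip when the gradient is nothing

    ok = try
        f(nothing, 1, nothing)   # matches both methods equally well
        true
    catch
        false
    end
    @assert !ok   # the call errors as ambiguous; extra disambiguating methods are needed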


function _accumulate!(cache::AbstractDict{Leaf,Any}, ℓ::Leaf, _, x̄s...)
acc_x̄s = get(cache, ℓ, missing)
cache[ℓ] = ismissing(acc_x̄s) ? x̄s : map(_add!, acc_x̄s, x̄s)
return
end
function _accumulate!(cache::AbstractDict{Leaf,Any}, tree, x, x̄s...)
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
foreach((stᵢ, xᵢ, x̄sᵢ...) -> _accumulate!(cache, stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# slightly cleaner way of closing over update! internal state
struct UpdateCallback
acc_grads::IdDict{Leaf,Any}
param_cache::IdDict{Leaf,Any}
end
Comment on lines +64 to +68

Member: The limitation might be mine, but I have to say I find this struct really hard to read, compared to just closing over things which have one name.

Member Author: I don't understand what "just closing over things which have one name" entails here, can you elaborate? Another reason for the struct over a normal closure is self-recursion, which I use here.

Member: I mean things like this, which define a dict & then use it:

   cache = IdDict{Leaf,Any}()
   _accumulate!(cache, tree, x, x̄s...)

With no further names: no structs, no field names.

Member Author: I recall trying this first, and deciding to bundle things into a struct after seeing a lot of long, long lines from threading the two IdDicts through multiple levels of functions. It may also have been tricky to get that working in a backwards compatible way, but it's been long enough that I don't remember the whole context.
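
As a reference point for the closure-versus-struct question, a rough sketch with made-up names (walk_with_cache, visit), not the PR's code: a named inner function can both close over a cache and call itself recursively, so self-recursion alone does not force a callable struct.

    function walk_with_cache(tree)
        cache = IdDict{Any,Int}()                # closed-over state, no struct or field names
        function visit(node)                     # named, so it can recurse into itself
            node isa AbstractArray && return get!(() -> length(cache) + 1, cache, node)
            node isa Union{Tuple,NamedTuple} && return map(visit, node)
            return node
        end
        return visit(tree), cache
    end

    w = [1.0, 2.0]
    labels, _ = walk_with_cache((w, (w, [3.0])))
    @assert labels == (1, (1, 2))   # the tied array receives the same label both times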


(::UpdateCallback)(::Nothing, x, x̄s...) = nothing, x
(::UpdateCallback)(::Nothing, x, ::Zero, ::Zero...) = nothing, x
(::UpdateCallback)(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
(::UpdateCallback)(tree, x, ::Zero, ::Zero...) = tree, x

(cb::UpdateCallback)(ℓ::Leaf, x, x̄s...) = get!(cb.param_cache, ℓ) do
update!(ℓ, x, pop!(cb.acc_grads, ℓ)...)
end
function (cb::UpdateCallback)(tree, x, x̄s...)
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, re = functor(typeof(x), x)
xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
map(first, xtree), re(map(last, xtree))
xtree = map(cb, tree, x′, x̄s′...)
return map(first, xtree), re(map(last, xtree))
end
Comment on lines +81 to 83

Member: This complication exists, I think, to reconstruct both the tree and the model on the way out of the recursion. But once Leaf is mutable, can't we skip that, and just mutate it? Just call fmap?

Member Author: Yes, absolutely. I held off from doing that here in case some user was stashing old state trees and would be blindsided by the values in those leaves suddenly changing.

Member: Not sure I follow. update! claimed it would mutate the states if it wanted to, and would typically alter arrays. (And update claimed not to, but had a bug.)

Member Author (@ToucheSir, Aug 28, 2022): But if you had immutable arrays in your state tree before, the original state tree would be unchanged after update!. Perhaps we don't feel that was ever a solid guarantee (I don't), but we ought to get that point out in writing for posterity.

Member: Sure. I think we should change the doc for update! to be explicit that it now guarantees to update the state tree. (I thought the old one said "inputs are trash afterwards", but in fact it is explicit only about the model.)

update! need not in fact return two arguments, but whether that is too confusing to change (and to differ from update, which must) is another question.
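
For context on the "just call fmap" idea, a self-contained sketch with a toy mutable state type (ToyState is illustrative, and this leans on fmap's documented identity cache applying the function once per distinct non-isbits leaf): a single fmap pass can mutate every state in place, with tied leaves visited only once.

    using Functors   # already a dependency of Optimisers.jl

    mutable struct ToyState
        t::Int
    end

    tied = ToyState(0)
    tree = (a = tied, b = (tied, ToyState(0)))

    fmap(tree; exclude = x -> x isa ToyState) do s
        s.t += 1    # update the existing object in place
        s           # return it so the reconstructed tree reuses it
    end

    @assert tree.a.t == 1 && tree.b[1].t == 1   # the tied state was bumped once, not twice
    @assert tree.b[2].t == 1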


function update(tree, x, x̄s...)
t′ = fmap(copy, tree; exclude = iswriteable)
x′ = fmap(copy, x; exclude = iswriteable)
update!(t′, x′, x̄s...)
# because we rely on Leaf identity for tied parameters, they require special treatment
cache = IdDict()
tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ
Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable))
end
x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable)
x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable)
Comment on lines +88 to +92

Member Author: It turns out we were not defensively copying state or gradients before, so they could still be mutated by a call to update.

Member: Seems fine never to copy gradients. It's never safe to mutate them anyway; a rule which does so (or an rrule likewise) is simply a bug.

For copying state, can't we just say @functor Leaf (state,) and let fmap do it?

Member Author (@ToucheSir, Aug 28, 2022): That breaks Leaf identity, unfortunately. fmap will end up untying shared parameters by creating new leaves at each location during reconstruction.

Not defensively copying gradients seems fine though, good point.

Member: Won't fmap preserve the Leaf identifications? That's what its cache is for.

Member Author (@ToucheSir, Aug 28, 2022): I had a recollection that it sometimes only preserved leaves, but re-reading the code you are correct.
return update!(tree′, x′, x̄s′...)
end
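
Rounding out the fmap-cache point from the thread above, a short self-contained check with plain arrays (not Optimisers' Leaf) that a cached fmap(copy, ...) keeps ties intact: the shared array is copied once, and that same copy is reused at every tied position.

    using Functors

    w = [1.0, 2.0]
    model = (layers = (w, [3.0]), head = w)          # w appears twice: a tie

    model′ = fmap(copy, model; exclude = x -> x isa Array)

    @assert model′.layers[1] === model′.head         # still tied after the defensive copy
    @assert model′.layers[1] !== w                   # but no longer aliases the original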

# default all rules to first order calls
79 changes: 71 additions & 8 deletions test/runtests.jl
@@ -22,7 +22,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
g = ([25, 33],)
o = Descent(0.1)
s = Optimisers.setup(o, m)

s2, m2 = Optimisers.update(s, m, g)
@test m[1] == 1:2 # not mutated
@test Optimisers.iswriteable(m[1])
@@ -157,13 +157,76 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
end

@testset "tied weights" begin
ok = (1.0:3.0, sin, "abc", :abc)
m = (α = ok, β = rand(3), γ = ok)
m1 = (rand(3), m, rand(3))
@test Optimisers.setup(AdamW(), m1) isa Tuple
m2 = (rand(3), m, rand(3), m, rand(3)) # illegal
@test_throws ArgumentError Optimisers.setup(AdamW(), m2)
end
@testset "tuples" begin
twice = [1,2.0]
mtup = (twice, (copy(twice), twice)) # (tied (not tied, tied))

# simplest rule for which opt(g1) + opt(g2) != opt(g1 + g2)
stup = Optimisers.setup(Momentum(0.1), mtup)
gtup = ([3,3], ([10,10], [7,7])) # (g1, (g1 + g2, g2))

snew, mnew = Optimisers.update(stup, mtup, gtup)
@test mnew[1] ≈ mnew[2][1] # gradient was accumulated
@test mnew[2][2] === mnew[1] # and tie is not broken

st3, mt3 = Optimisers.update(stup, mtup, ([3,3], nothing))
@test mt3[1] ≈ [1,2] - 0.1 * [3,3]
@test mt3[2][2] === mt3[1]

st4, mt4 = Optimisers.update(stup, mtup, (nothing, ([5,5], [7,7])))
@test mt4[1] ≈ [1,2] - 0.1 * [7,7]
end

@testset "named" begin
thrice = [3f0]
model = (a = (x = thrice, y = Float32[4,5,6], z = true), b = ((m = (0, 1, thrice),),), c = (x = Float32[7,8], y = thrice))
tree = Optimisers.setup(Momentum(0.1, 0.9), model)
@test model.a.x === model.b[1].m[3] == model.c.y

loss(x::Array) = sum(abs2, x)
loss(x::Number) = x^3
loss(m) = sum(2 * loss(x) for x in m)
gradient(loss, model)
_, m2 = Optimisers.update(tree, model, gradient(loss, model)...)
@test m2.a.x === m2.b[1].m[3] == m2.c.y

loss3(m) = sum(x isa Tuple ? 0 : 2 * loss(x) for x in m)
gradient(loss3, model) # truncates the b limb
_, m3 = Optimisers.update(tree, model, gradient(loss3, model)...)
@test m3.a.x === m3.b[1].m[3] == m3.c.y
end

@testset "transpose" begin
mat = [1 2 3; 4 5 6.0]
bidir = (m = mat, f = log, t = transpose(mat), v = [7, 8, 9.0])
bigrad, _ = gradient((m, x) -> sum(abs2, m.m * (m.f).(m.t*x .+ m.v)), bidir, [1, 0.1])
@test bigrad.t isa Matrix # not a Transpose, that's the point here

state = Optimisers.setup(Descent(0.1), bidir)
@test state.t.parent === state.m # successfully tied

s2, b2 = Optimisers.update(state, bidir, bigrad)
@test b2.t.parent === b2.m # tie restored
@test b2.m ≈ bidir.m - 0.1 * (bigrad.m + transpose(bigrad.t)) # grad accumulated

state = Optimisers.setup(OptimiserChain(ClipGrad(10), Descent(0.1), ClipGrad(10)), bidir)
s2, b2 = Optimisers.update(state, bidir, bigrad)
@test b2.t.parent === b2.m
@test b2.m ≈ bidir.m - 0.1 * clamp.((bigrad.m + transpose(bigrad.t)), -10, 10)

# Similar, but now "primary" field is the transposed one:
tri = (a = transpose(mat), b = mat, c = transpose(mat), d = 4.0)
trigrad = gradient(m -> sum(abs2, m.a * (m.b * (m.c * [0.1, 1] .+ m.d) .- m.d)), tri)[1]
stri = Optimisers.setup(Descent(0.1), tri)
s3, t3 = Optimisers.update(stri, tri, trigrad)
@test t3.a.parent === t3.b === t3.c.parent
@test t3.a ≈ tri.a - 0.1 * (trigrad.a + trigrad.b' + trigrad.c)

g4 = (a = Broadcast.broadcasted(+, mat', 1), b = nothing, c = @thunk(mat' .+ 1), d = nothing)
# Error: no constructors for type Any
@test_broken s4, t4 = Optimisers.update(stri, tri, g4)
end
end # tied weights

end
@testset verbose=true "Destructure" begin