Transparent handling of tied weights #100
@@ -6,47 +6,91 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
 
 abstract type AbstractRule end
 
-struct Leaf{R,S}
+mutable struct Leaf{R,S}
   rule::R
   state::S
 end
 
-function setup(rule, x; seen = Base.IdSet())
+function setup(rule, x; cache = IdDict{Any,Leaf}())
   rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
   if isnumeric(x)
-    x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
-    isbits(x) || push!(seen, x)
-    return Leaf(rule, init(rule, x))
+    leaf = get(cache, x, missing)
+    ismissing(leaf) || return leaf
+    leaf = Leaf(rule, init(rule, x))
+    isbits(x) || (cache[x] = leaf)
+    return leaf
   elseif isleaf(x)
     return nothing
   else
-    return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
+    return map(xᵢ -> setup(rule, xᵢ; cache), _trainable(x))
   end
 end
 
+_add!(x, x̄) = iswriteable(x) ? (x .= x .+ x̄) : eltype(x).(x .+ x̄)
Not sure this is what we want. We should never ever mutate a gradient, but I think we can just call […]

My worry with the lazy accumulation approach is threefold. First, it blows any chance of making this type stable out the window. Secondly, it's possible the lazy […]

We can do it eagerly to avoid this. But we cannot mutate the gradients, as they may be shared with others (e.g. from the rule for +). Lazy […]

Good point about aliased gradients. If this is a correctness issue, we don't have much of a choice :)
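To make the aliasing concern in this thread concrete, here is a minimal sketch in plain Julia. The names add_copy and add_inplace! are hypothetical and are not the PR's _add!; the `shared` array is only a stand-in for an AD rule that hands the same gradient array to more than one place.

```julia
# Hypothetical helpers for illustration only; not part of Optimisers.jl.

# Out-of-place accumulation: allocates a fresh array, leaves both inputs untouched.
add_copy(acc, grad) = acc .+ grad

# In-place accumulation: overwrites `acc`, which is unsafe when `acc` is a
# gradient array that something else still references.
add_inplace!(acc, grad) = (acc .+= grad)

# Stand-in for a rule that returns the same array for two tied parameters.
shared = [1.0, 2.0]
g1, g2 = shared, shared              # two "gradients" aliasing one array

safe = add_copy(g1, g2)              # new array; `shared` is unchanged
shared_after_copy = copy(shared)     # snapshot taken before any mutation

add_inplace!(g1, g2)                 # mutates `shared`, so g2 changes as well
```

After the in-place call, g2 no longer holds the gradient it was given, which is the correctness problem raised above. Accumulating into a fresh array costs one allocation per tied parameter but never touches the caller's gradients.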
 subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
 update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
 update!(::Nothing, x, x̄s...) = nothing, x
 
 update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
 function update!(ℓ::Leaf, x, x̄s...)
-  s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
-  Leaf(ℓ.rule, s′), subtract!(x, x̄′)
+  ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, map(base, x̄s)...)
+  return ℓ, subtract!(x, x̄′)
 end
 
 update!(tree, x, ::Zero, ::Zero...) = tree, x
 function update!(tree, x, x̄s...)
+  cache = IdDict{Leaf,Any}()
+  _accumulate!(cache, tree, x, x̄s...)
+  return UpdateCallback(cache, IdDict{Leaf,Any}())(tree, x, x̄s...)
+end
+
+_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing
+_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing
Comment on lines +48 to +51

There are a lot of overloads with high degrees of overlap across multiple functions. I couldn't think of a way to deduplicate some of them, so if anyone has ideas that would be swell.

I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106. I also think it would be clearer to write variable names more often, not […]

I think I tried that, but ran into ambiguities. This is the smallest number of methods I could come up with that didn't have ambiguities. If you can narrow that down, that would be superb.

The underscores are mostly to appease the linter and possibly improve latency(??) Perhaps […]

I don't see how names can affect latency. I just mean they let your eye know what the 4th argument means, which […]

My impression was it implicitly acted as a […]
+
+function _accumulate!(cache::AbstractDict{Leaf,Any}, ℓ::Leaf, _, x̄s...)
+  acc_x̄s = get(cache, ℓ, missing)
+  cache[ℓ] = ismissing(acc_x̄s) ? x̄s : map(_add!, acc_x̄s, x̄s)
+  return
+end
+function _accumulate!(cache::AbstractDict{Leaf,Any}, tree, x, x̄s...)
+  x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
+  x′, _ = functor(typeof(x), x)
+  foreach((stᵢ, xᵢ, x̄sᵢ...) -> _accumulate!(cache, stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
+end
+
+# slightly cleaner way of closing over update! internal state
+struct UpdateCallback
+  acc_grads::IdDict{Leaf,Any}
+  param_cache::IdDict{Leaf,Any}
+end
Comment on lines +64 to +68

The limitation might be mine but I have to say I find this struct really hard to read, compared to just closing over things which have one name.

I don't understand what "just closing over things which have one name." entails here, can you elaborate? Another reason for the struct over a normal closure is self-recursion, which I use here.

I mean things like this, which define a dict & then use it: […] With no further names: no structs, no field names.

I recall trying this first, and deciding to bundle things into a struct after seeing a lot of long, long lines from threading the two IdDicts through multiple levels of functions. It may also have been tricky to get that working in a backwards compatible way, but it's been long enough that I don't remember the whole context.
+
+(::UpdateCallback)(::Nothing, x, x̄s...) = nothing, x
+(::UpdateCallback)(::Nothing, x, ::Zero, ::Zero...) = nothing, x
+(::UpdateCallback)(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
+(::UpdateCallback)(tree, x, ::Zero, ::Zero...) = tree, x
+
+(cb::UpdateCallback)(ℓ::Leaf, x, x̄s...) = get!(cb.param_cache, ℓ) do
+  update!(ℓ, x, pop!(cb.acc_grads, ℓ)...)
+end
+function (cb::UpdateCallback)(tree, x, x̄s...)
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, re = functor(typeof(x), x)
-  xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-  map(first, xtree), re(map(last, xtree))
+  xtree = map(cb, tree, x′, x̄s′...)
+  return map(first, xtree), re(map(last, xtree))
 end
Comment on lines +81 to 83

This complication exists I think to reconstruct both the tree and the model on the way out of the recursion. But once […]

Yes, absolutely. I held off from doing that here in case some user was stashing old state trees and would be blindsided by the values in those leaves suddenly changing.

Not sure I follow.

But if you had immutable arrays in your state tree before, the original state tree would be unchanged after […]

Sure. I think we should change the doc for […]
 
 function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = iswriteable)
-  x′ = fmap(copy, x; exclude = iswriteable)
-  update!(t′, x′, x̄s...)
+  # because we rely on Leaf identity for tied parameters, they require special treatment
+  cache = IdDict()
+  tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ
+    Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable))
+  end
+  x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable)
+  x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable)
Comment on lines +88 to +92

It turns out we were not defensively copying state or gradients before, so they could still be mutated by a call to […]

Seems fine never to copy gradients. It's never safe to mutate them anyway, a rule which does so (or an […] For copying state, can't we just say […]

[…] That breaks […] Not defensively copying gradients seems fine though, good point.

Won't […]

I had a recollection that it sometimes only preserved leaves, but re-reading the code you are correct.
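The question in this thread is whether a defensive copy keeps tied arrays tied. As a rough illustration only, the sketch below copies a nested NamedTuple while caching copies by object identity in an IdDict. The function copy_preserving_ties is hypothetical; it is not Functors.fmap and makes no claim about how fmap is implemented.

```julia
# Hypothetical helper for illustration; not Functors.fmap and not part of the PR.
# Copies every numeric array once, keyed by identity, so arrays that were the
# same object before the copy are the same object after it.
function copy_preserving_ties(x, cache = IdDict())
  if x isa AbstractArray{<:Number}
    return get!(() -> copy(x), cache, x)   # reuse the copy if x was seen before
  elseif x isa Union{Tuple,NamedTuple}
    return map(xi -> copy_preserving_ties(xi, cache), x)
  else
    return x                               # non-array leaves are passed through
  end
end

w = [1.0, 2.0]
model = (enc = w, dec = w, bias = [0.0])   # enc and dec are tied
copied = copy_preserving_ties(model)
```

Here copied.enc and copied.dec are one new array, distinct from w, so mutating the copy affects both tied fields and leaves the original model alone. Dropping the cache would silently untie them.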
+  return update!(tree′, x′, x̄s′...)
 end
 
 # default all rules to first order calls
Once this is mutable, then update!(tree, model, grad) can be guaranteed to alter the state tree in place. This opens the possibility of simplifying the interface, and never returning multiple things whose order you have to remember.

See https://github.com/FluxML/Optimisers.jl/pull/100/files#r956755677.
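A small sketch of what a mutable leaf buys, under stated assumptions: ToyLeaf, ToyDescent, toy_apply and toy_update! are hypothetical stand-ins written for this note, not the Optimisers.jl types or methods, and the "state" is just a step counter.

```julia
# Hypothetical stand-in for a mutable leaf holding a rule and its state.
mutable struct ToyLeaf{R,S}
  rule::R
  state::S
end

# Toy rule: the state counts steps, the update is eta times the gradient.
struct ToyDescent
  eta::Float64
end

toy_apply(rule::ToyDescent, state::Int, x, grad) = (state + 1, rule.eta .* grad)

function toy_update!(leaf::ToyLeaf, x, grad)
  # Writing the new state into the existing leaf keeps its identity intact.
  leaf.state, dx = toy_apply(leaf.rule, leaf.state, x, grad)
  x .-= dx
  return leaf, x
end

leaf = ToyLeaf(ToyDescent(0.1), 0)
stashed = leaf                      # a reference kept from before the update
x = [1.0, 2.0]
toy_update!(leaf, x, [10.0, 10.0])
```

Because the leaf is updated in place, `stashed` sees the new state as well. That is what makes ignoring the return value safe, and it is also the behaviour change that anyone holding on to old state trees would notice.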