|
1 | 1 |
|
2 |
| -using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero |
| 2 | +using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent |
# Strip ChainRules' structural-tangent wrapper down to its plain backing
# (usually a NamedTuple), so gradients from different ADs look alike below.
function base(dx::Tangent)
  return backing(canonicalize(dx))
end
base(dx) = dx  # anything that is not a Tangent passes through unchanged
|
# "No gradient" markers: Zygote uses `nothing`, Diffractor/ChainRules use
# subtypes of `AbstractZero`. `Zero` lets methods accept either uniformly.
const Zero = Union{Nothing, AbstractZero}

# Supertype of every optimisation rule (Descent, Adam, ...); rules implement
# `init` and `apply!` against this interface.
abstract type AbstractRule end
|
8 | 8 |
|
###
### setup
###

"""
    Leaf(rule, state)

Pairs an optimisation rule with the per-parameter state it needs.
Declared `mutable` so that object identity (`===`) of two `Leaf`s can
encode parameter sharing (tied weights), and so `apply!` may update
`state` in place.
"""
mutable struct Leaf{R,S}
  rule::R
  state::S
end
|
13 | 17 |
|
@functor Leaf

# Compare Leafs by contents rather than identity (handy in tests).
# NOTE(review): `==` is overloaded without a matching `Base.hash`, so two
# equal Leafs may hash differently in a `Dict`/`Set`; the identity-keyed
# `IdDict`s used below are unaffected — confirm this is intentional.
Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
"""
    setup(rule, model)

Walk `model` and attach a fresh `Leaf(rule, state)` to every trainable
array, mirroring the model's structure. Warns when no trainable
parameters are found at all.
"""
function setup(rule::AbstractRule, model)
  seen = IdDict()  # identity cache so shared arrays get one shared Leaf
  tree = _setup(rule, model; cache = seen)
  isempty(seen) && @warn "setup found no trainable parameters in this model"
  return tree
end
| 28 | + |
# Almost `fmapstructure`, except that it walks `_trainable` children and
# keeps an identity cache so a shared (`===`) array receives one shared Leaf.
# Numbers and other isbits leaves have no useful identity, so they bypass
# the cache entirely.
function _setup(rule, x; cache)
  haskey(cache, x) && return cache[x]
  isnumeric(x) || return map(child -> _setup(rule, child; cache), _trainable(x))
  leaf = Leaf(rule, init(rule, x))
  if isbits(x)
    cache[nothing] = nothing  # records that we saw *something*, just to silence setup's warning
    return leaf
  end
  return cache[x] = leaf
end
|
26 | 44 |
|
27 |
| -subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) |
# Compact display, mainly to hide Leaf's very long type parameters.
function Base.show(io::IO, ℓ::Leaf)
  compact_io = IOContext(io, :compact => true)
  print(compact_io, "Leaf(", ℓ.rule, ", ")
  show(compact_io, ℓ.state)
  print(compact_io, ")")
end
28 | 51 |
|
29 |
###
### update
###

"""
    update(tree, model, gradient, higher...)

Out-of-place version of [`update!`](@ref): first copies every writable
array inside both the state `tree` and the `model`, then applies the
gradient, leaving the originals untouched.
"""
function update(tree, model, grad, higher...)
  tree′ = fmap(copy, tree; exclude = maywrite)   # this walk reaches inside Leaf
  model′ = fmap(copy, model; exclude = maywrite)
  return update!(tree′, model′, grad, higher...)
end
|
37 | 61 |
|
38 |
| -update!(tree, x, ::Zero, ::Zero...) = tree, x |
39 |
| -function update!(tree, x, x̄s...) |
40 |
| - x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) |
41 |
| - x′, re = functor(typeof(x), x) |
42 |
| - xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) |
43 |
| - map(first, xtree), re(map(last, xtree)) |
| 62 | +function update!(tree, model, grad, higher...) |
| 63 | + # First walk is to accumulate the gradient. This recursion visits every copy of |
| 64 | + # shared leaves, but stops when branches are absent from the gradient: |
| 65 | + grads = IdDict{Leaf, Any}() |
| 66 | + _grads!(grads, tree, model, grad, higher...) |
| 67 | + # Second walk is to update the model. The params cache indexed by (tree,x), |
| 68 | + # so that identified Leafs can tie isbits parameters, but setup won't do that for you: |
| 69 | + newmodel = _update!(tree, model; grads, params = IdDict()) |
| 70 | + tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree. |
| 71 | +end |
| 72 | + |
# Rebuild one interior node of the model, recursing through functor children.
function _update!(tree, x; grads, params)
  haskey(params, (tree, x)) && return params[(tree, x)]
  isbits(tree) && return x  # e.g. () and (((),),): nothing trainable below, and not worth caching
  kids, rebuild = functor(x)
  x_new = rebuild(map((t, c) -> _update!(t, c; grads, params), tree, kids))
  # Immutable structs carry no ties worth preserving, so only cache mutables.
  ismutable(x_new) || return x_new
  return params[(tree, x)] = x_new
end
|
# Leaf case: apply the rule to this parameter (once per distinct (ℓ, x) pair).
function _update!(ℓ::Leaf, x; grads, params)
  haskey(params, (ℓ, x)) && return params[(ℓ, x)]
  if !haskey(grads, ℓ)  # no gradient ever reached this parameter
    return params[(ℓ, x)] = x
  end
  ℓ.state, dx = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
  return params[(ℓ, x)] = subtract!(x, dx)
end
| 93 | + |
# In-place `x .-= dx` when `x` is writable; otherwise build a fresh array,
# converted back to `x`'s eltype.
function subtract!(x, dx)
  return maywrite(x) ? (x .= x .- dx) : eltype(x).(x .- dx)
end
45 | 95 |
|
46 |
# Accumulate gradients into `dict`, keyed by Leaf identity, so that tied
# parameters sum their contributions.
_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
  acc = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
  # Adding a ZeroTangent should be free; lazy accumulation via
  # broadcasted(+, x̄, x̄₀) would also be possible.
  dict[ℓ] = map(+, x̄s, acc)
  return nothing
end
_grads!(dict::IdDict, t, x, ::Zero...) = nothing
function _grads!(dict::IdDict, tree, x, x̄s...)
  # The model `x` is carried along only because functor(typeof(x), base(x̄))
  # may differ from functor(typeof(tree), base(x̄)), e.g. for Transpose.
  x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
  x′, _ = functor(typeof(x), x)
  foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end
|
51 | 110 |
|
# Default every rule to its first-order method: rules that implement no
# higher-order update simply drop dx2 and beyond.
function apply!(o, state, x, dx, dx2, dxs...)
  return apply!(o, state, x, dx)
end
|
54 | 113 |
|
| 114 | +### |
| 115 | +### sources of truth |
| 116 | +### |
| 117 | + |
55 | 118 | """
|
56 | 119 | isnumeric(x) -> Bool
|
57 | 120 |
|
@@ -98,8 +161,12 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu
|
98 | 161 | map(c -> c in tr ? c : nothing, ch)
|
99 | 162 | end
|
100 | 163 |
|
| 164 | +### |
| 165 | +### rule definition helpers |
| 166 | +### |
| 167 | + |
101 | 168 | """
|
102 |
| - @.. x = x + y |
| 169 | + @.. x = y + z |
103 | 170 |
|
104 | 171 | Sometimes in-place broadcasting macro, for use in `apply!` rules.
|
105 | 172 | If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
|
@@ -135,11 +202,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
|
135 | 202 |
|
# An array shaped like `x` whose every element is `λ`; the second method
# first converts `λ` to the float type matching `x`'s eltype.
onevalue(v::T, xs::AbstractArray{T}) where T = map(_ -> v, xs)
onevalue(v, xs::AbstractArray{T}) where T = onevalue(convert(float(T), v), xs)
|
138 |
| - |
139 |
| -function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! |
140 |
| - ioc = IOContext(io, :compact => true) |
141 |
| - print(ioc, "Leaf(", ℓ.rule, ", ") |
142 |
| - show(ioc, ℓ.state) |
143 |
| - print(io, ")") |
144 |
| -end |
145 |
| - |
|
0 commit comments