|
1 | 1 |
|
2 |
| -using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero |
| 2 | +using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent |
base(dx::Tangent) = backing(canonicalize(dx))  # unwrap a ChainRules Tangent to its backing NamedTuple/Tuple
base(dx) = dx  # any other gradient representation passes through unchanged

# "No gradient" markers from either AD system: Zygote uses `nothing`,
# Diffractor/ChainRules use subtypes of `AbstractZero`.
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

# Supertype for all optimisation rules (their `apply!`/`init` methods live elsewhere).
abstract type AbstractRule end
|
8 | 8 |
|
###
### setup
###

# One optimiser-state node per trainable array. Mutable so that `update!`
# can write the new state back in place (see `ℓ.state = s′` there), which
# also lets tied (shared) arrays share a single Leaf in the state tree.
# When `frozen` is true, `update!` leaves the corresponding array unchanged.
mutable struct Leaf{R,S}
    rule::R    # the optimisation rule, e.g. some AbstractRule instance
    state::S   # per-array state produced by init(rule, x)
    frozen::Bool
end
|
13 | 18 |
|
14 |
| -function setup(rule, x; seen = Base.IdSet()) |
15 |
| - rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup) |
16 |
| - if isnumeric(x) |
17 |
| - x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry.")) |
18 |
| - isbits(x) || push!(seen, x) |
19 |
| - return Leaf(rule, init(rule, x)) |
20 |
| - elseif isleaf(x) |
21 |
| - return nothing |
22 |
| - else |
23 |
| - return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x)) |
# Register Leaf with Functors so fmap/fmapstructure recurse into it.
@functor Leaf

# Compare Leaves by their contents (children), not by object identity.
# NOTE(review): no matching `hash` method is visible here — confirm Leaves
# are not used as Dict keys by value elsewhere.
Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
| 22 | + |
"""
    setup(rule::AbstractRule, model)

Build a tree of optimiser state mirroring `model`, with a fresh
`Leaf(rule, init(rule, x), false)` at every array accepted by `isnumeric`.
Warns when no such parameters are found.
"""
function setup(rule::AbstractRule, model)
    cnt = Ref(0)  # counts Leaves created, solely to detect an empty model
    # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
    tree = fmapstructure(model, exclude = isnumeric) do x
        cnt[] += 1
        Leaf(rule, init(rule, x), false)
    end
    cnt[] == 0 && @warn "setup found no parameters in the given model"
    tree
end
|
26 | 33 |
|
27 |
| -subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) |
28 |
| - |
29 |
| -update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x |
30 |
| -update!(::Nothing, x, x̄s...) = nothing, x |
# Custom printing for Leaf; exists mostly to hide its very long type parameters.
function Base.show(io::IO, leaf::Leaf)
    ctx = IOContext(io, :compact => true)  # keep rule & state displays short
    print(ctx, "Leaf(", leaf.rule, ", ")
    show(ctx, leaf.state)
    print(ctx, ", ", leaf.frozen, ")")
end
31 | 40 |
|
32 |
| -update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x |
33 |
| -function update!(ℓ::Leaf, x, x̄s...) |
34 |
| - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...) |
35 |
| - Leaf(ℓ.rule, s′), subtract!(x, x̄′) |
###
### update
###

"""
    update!(tree, model, grad)

Apply the optimiser state `tree` to `model` using gradient `grad`, returning
`(tree, newmodel)`. Each `Leaf`'s state is mutated in place, so the returned
`tree` is always current even when arrays cannot be written in place.
"""
function update!(tree, model, grad)
    # First walk is to accumulate the gradient. This recursion visits every copy of
    # shared leaves, but stops when branches are absent from the gradient:
    dict = IdDict{Leaf, Any}()
    grads!(dict, tree, model, grad)
    # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once:
    newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ
        ℓ isa Leaf || error("this state does not match the model, expected a Leaf here")
        ℓ.frozen && return x            # frozen leaves skip the update entirely
        haskey(dict, ℓ) || return x     # no gradient reached this leaf
        s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
        ℓ.state = s′ # to get state out of here, rely on mutability of Leaf
        subtract!(x, x̄′)
    end
    tree, newmodel # note that tree is guaranteed to be updated
end
|
37 | 61 |
|
38 |
| -update!(tree, x, ::Zero, ::Zero...) = tree, x |
39 |
| -function update!(tree, x, x̄s...) |
# Subtract the computed step x̄ from x: in place when the array is writable,
# otherwise out of place (converting back to x's eltype).
subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)

# Accumulate gradients into `dict` keyed by Leaf, so tied arrays sum their gradients.
grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing  # zero gradient: record nothing
function grads!(dict::IdDict, ℓ::Leaf, x, x̄)
    # `false` is the neutral start value (x̄ .+ false == x̄); the sum is kept as a
    # lazy broadcast rather than materialised here.
    x̄₀ = get(dict, ℓ, false)
    dict[ℓ] = Broadcast.broadcasted(+, x̄, x̄₀)
    nothing
end
grads!(dict::IdDict, t, x, ::Zero) = nothing  # whole branch absent from the gradient
# Recurse through matching (state, model, gradient...) branches, filling `dict`.
function grads!(dict::IdDict, tree, x, x̄s...)
    # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from
    # functor(typeof(tree), base(x̄)), for things like Transpose
    x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
    x′, _ = functor(typeof(x), x)  # reconstructor is unused; only children needed
    foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end
|
45 | 78 |
|
# Non-mutating variant: deep-copies the writable arrays in both the state tree
# (fmap goes inside each Leaf) and the model, then delegates to update!.
function update(tree, x, x̄s...)
    copied_tree = fmap(copy, tree; exclude = maywrite)
    copied_model = fmap(copy, x; exclude = maywrite)
    return update!(copied_tree, copied_model, x̄s...)
end
|
51 | 84 |
|
# Fallback so every rule works with higher-order gradient calls:
# extra gradients beyond the first are ignored by default.
function apply!(o, state, x, dx, dxs...)
    return apply!(o, state, x, dx)
end
|
54 | 87 |
|
| 88 | +### |
| 89 | +### sources of truth |
| 90 | +### |
| 91 | + |
55 | 92 | """
|
56 | 93 | isnumeric(x) -> Bool
|
57 | 94 |
|
@@ -98,6 +135,10 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu
|
98 | 135 | map(c -> c in tr ? c : nothing, ch)
|
99 | 136 | end
|
100 | 137 |
|
| 138 | +### |
| 139 | +### rule definition helpers |
| 140 | +### |
| 141 | + |
101 | 142 | """
|
102 | 143 | @.. x = x + y
|
103 | 144 |
|
@@ -135,11 +176,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
|
135 | 176 |
|
# Build an array shaped like `x`, filled with the scalar `λ`.
# When `λ`'s type does not already match `x`'s eltype, it is first converted
# to the float version of that eltype, then the matching method fills.
function onevalue(λ::T, x::AbstractArray{T}) where T
    return map(_ -> λ, x)
end
function onevalue(λ, x::AbstractArray{T}) where T
    return onevalue(convert(float(T), λ), x)
end
|
138 |
| - |
139 |
| -function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! |
140 |
| - ioc = IOContext(io, :compact => true) |
141 |
| - print(ioc, "Leaf(", ℓ.rule, ", ") |
142 |
| - show(ioc, ℓ.state) |
143 |
| - print(io, ")") |
144 |
| -end |
145 |
| - |
|
0 commit comments