 Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
 
 function setup(rule::AbstractRule, model)
-  cnt = Ref(0)
-  # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
-  tree = fmapstructure(model, exclude = isnumeric) do x
-    cnt[] += 1
-    Leaf(rule, init(rule, x))
-  end
-  cnt[] == 0 && @warn "setup found no parameters in the given model"
+  cache = IdDict()
+  tree = _setup(rule, model; cache)
+  isempty(cache) && @warn "setup found no trainable parameters in this model"
   tree
 end
 
+# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
+function _setup(rule, x; cache)
+  haskey(cache, x) && return cache[x]
+  if isnumeric(x)
+    ℓ = Leaf(rule, init(rule, x))
+    if isbits(x)
+      cache[nothing] = nothing  # just to disable the warning
+      ℓ
+    else
+      cache[x] = ℓ
+    end
+  else
+    map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
+  end
+end
+
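For context, a minimal sketch of what the cache buys (this tied-weight NamedTuple model is hypothetical; `Adam` is one of this package's rules, spelled `ADAM` in older releases):

    W = rand(3, 3)                          # one array, used in two places
    model = (enc = (weight = W,), dec = (weight = W,))
    state = setup(Adam(), model)
    state.enc.weight === state.dec.weight   # true: shared arrays share one Leaf

Since `isbits` parameters have no identity to share, they never enter the cache; the `cache[nothing] = nothing` line merely makes the cache non-empty so that such models don't trigger the warning.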
 function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
   ioc = IOContext(io, :compact => true)
   print(ioc, "Leaf(", ℓ.rule, ", ")
⋮
 ### update
 ###
 
-function update!(tree, model, grad)
+function update(tree, model, grad, higher...)
+  t′ = fmap(copy, tree; exclude = maywrite)  # walks inside Leaf
+  x′ = fmap(copy, model; exclude = maywrite)
+  update!(t′, x′, grad, higher...)
+end
+
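The non-mutating `update` simply copies every writable array in both the state tree and the model (note that `exclude = maywrite` makes `fmap` copy the arrays inside each `Leaf` too), then reuses `update!` on the copies. A one-line usage sketch, with `state`, `model` and `grad` as in the hypothetical example above:

    state2, model2 = update(state, model, grad)   # originals left untouched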
+function update!(tree, model, grad, higher...)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
-  dict = IdDict{Leaf, Any}()
-  grads!(dict, tree, model, grad)
-  # Second walk is to update the model. The walk taken follows Leaf identity
-  newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
-    haskey(dict, ℓ) || return x  # no gradient seen, nothing to do
-    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
-    ℓ.state = s′  # to get state out of here, rely on mutability of Leaf
+  grads = IdDict{Leaf, Any}()
+  _grads!(grads, tree, model, grad, higher...)
+  # Second walk is to update the model. The params cache is indexed by (tree, x),
+  # so that identified Leafs can tie isbits parameters, but setup won't do that for you:
+  newmodel = _update!(tree, model; grads, params = IdDict())
+  tree, newmodel  # note that tree is guaranteed to be updated. Also that it's not necessarily a tree.
+end
+
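Because the first walk visits every copy of a shared `Leaf`, tied parameters see the sum of their gradients, and the `(tree, x)` cache in `_update!` below ensures the rule is applied only once per tie. A sketch, continuing the hypothetical tied-weight model:

    grad = (enc = (weight = ones(3, 3),), dec = (weight = ones(3, 3),))
    state, model = update!(state, model, grad)   # W steps once, by the rule applied to the summed gradient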
+function _update!(tree, x; grads, params)
+  haskey(params, (tree,x)) && return params[(tree,x)]
+  isbits(tree) && return x  # means () is not cached, and also (((),),)
+  x′, re = functor(x)
+  x′′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)
+  params[(tree,x)] = re(x′′)
+end
+function _update!(ℓ::Leaf, x; grads, params)
+  haskey(params, (ℓ,x)) && return params[(ℓ,x)]
+  params[(ℓ,x)] = if haskey(grads, ℓ)
+    ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
     subtract!(x, x̄′)
+  else
+    x  # no gradient seen
   end
-  tree, newmodel # note that tree is guaranteed to be updated
 end
 
 subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
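`subtract!` mutates its first argument only when `maywrite` judges it writable, and otherwise falls back to an out-of-place broadcast. A small sketch of both branches (assuming `maywrite` rejects ranges, which are read-only):

    v = [1.0, 2.0]
    subtract!(v, [0.5, 0.5])   # in place: v is now [0.5, 1.5]
    r = 1.0:2.0
    subtract!(r, [0.5, 0.5])   # returns a new Vector [0.5, 1.5]; r is unchanged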
-grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing
-function grads!(dict::IdDict, ℓ::Leaf, x, x̄)
-  x̄₀ = get(dict, ℓ, ZeroTangent())
-  dict[ℓ] = x̄ + x̄₀  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
+_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
+  x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
+  dict[ℓ] = map(+, x̄s, x̄s₀)  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
   nothing
 end
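Each visit to a shared `Leaf` adds that visit's gradients (one per derivative order) into the `dict` entry, starting from `ZeroTangent()` so the first addition is free. A hedged sketch of two visits to one Leaf, using the package's `Descent` rule with made-up arrays:

    dict = IdDict{Leaf, Any}()
    ℓ, w, ḡ = Leaf(Descent(0.1), nothing), rand(2), ones(2)
    _grads!(dict, ℓ, w, ḡ)
    _grads!(dict, ℓ, w, ḡ)
    dict[ℓ]                    # ([2.0, 2.0],): both visits accumulated into one tuple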
-grads!(dict::IdDict, t, x, ::Zero) = nothing
-function grads!(dict::IdDict, tree, x, x̄s...)
-  # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from
+_grads!(dict::IdDict, t, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, tree, x, x̄s...)
+  # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
   # functor(typeof(tree), base(x̄)), for things like Transpose
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, _ = functor(typeof(x), x)
-  foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-end
-
-function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = maywrite)  # goes inside Leaf
-  x′ = fmap(copy, x; exclude = maywrite)
-  update!(t′, x′, x̄s...)
-end
-
-# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first
-function _second_walk(f, x, y)
-  x′, _ = functor(typeof(y), x)
-  y′, re = functor(y)
-  re(map(f, x′, y′))
-end
-
-# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state.
-# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work.
-struct LeafCache <: AbstractDict{Leaf,Any}
-  dict::IdDict{Leaf,Any}
+  foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
 end
-LeafCache() = LeafCache(IdDict{Leaf,Any}())
-
-Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ)
-Base.setindex!(c::LeafCache, x, _) = nothing
-Base.in(k, c::LeafCache) = k in c.dict
-Base.haskey(c::LeafCache, k) = haskey(c.dict, k)
-Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ)
-Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i)
-Base.length(c::LeafCache) = length(c.dict)
 
 # default all rules to first order calls
 apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)
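This fallback means `update!(state, model, grad, grad2)` works with every rule: first-order rules silently drop the extra gradients, while a second-order rule opts in by overloading the longer signature. A hypothetical sketch of such a rule (the name and formula are invented; `AbstractRule`, `init` and `apply!` are this package's extension points):

    struct CurvatureScaled <: AbstractRule   # hypothetical second-order rule
      eta::Float64
    end
    init(::CurvatureScaled, x::AbstractArray) = nothing
    apply!(o::CurvatureScaled, state, x, dx, dx2) =
      state, o.eta .* dx ./ (abs.(dx2) .+ 1e-8)   # uses the second gradient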