@@ -10,10 +10,10 @@ abstract type AbstractRule end
10
10
# ## setup
11
11
# ##
12
12
13
# Per-parameter optimiser state: the `rule` being applied, that rule's current `state`,
# and a `frozen` flag. `update!` reads all three fields, overwrites `state` in place,
# and leaves the corresponding parameter untouched while `frozen` is true.
mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing
    rule::R
    state::S
    frozen::Bool  # mutability also allows this flag to be changed
end
18
18
19
19
# Register `Leaf` via the `@functor` macro (presumably the one from Functors.jl, which
# also provides the `functor`/`fmap` used below — confirm against this file's imports).
@functor Leaf
45
45
# Take one optimisation step. `tree` is the optimiser state (`Leaf`s arranged to be walked
# alongside `model`), and `grad` is the gradient handed to `grads!`.
# Returns `(tree, newmodel)`. Every `Leaf` that is not frozen and received a gradient has
# its `state` field replaced in place; the new parameter is whatever `subtract!` returns.
function update!(tree, model, grad)
    # First walk is to accumulate the gradient. This recursion visits every copy of
    # shared leaves, but stops when branches are absent from the gradient:
    dict = IdDict{Leaf, Any}()
    grads!(dict, tree, model, grad)
    # Second walk is to update the model. The walk taken follows Leaf identity
    newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x
        ℓ.frozen && return x
        haskey(dict, ℓ) || return x  # no gradient seen, nothing to do
        s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
        ℓ.state = s′  # to get state out of here, rely on mutability of Leaf
        return subtract!(x, x̄′)
    end
    return tree, newmodel  # note that tree is guaranteed to be updated
end
@@ -89,6 +81,13 @@ function update(tree, x, x̄s...)
89
81
update! (t′, x′, x̄s... )
90
82
end
91
83
84
# This differs from _default_walk(f, x, y) in taking `re` from the 2nd argument,
# but the cache will still operate on the first.
function _second_walk(f, x, y)
    x′, _ = functor(typeof(y), x)  # children of x, taken using y's type
    y′, re = functor(y)            # children of y, plus the function that rebuilds y
    return re(map(f, x′, y′))
end
90
+
92
91
# Default all rules to first order calls: when extra derivative arguments
# (`dx2`, `dxs...`) are supplied, drop them and forward to the 4-argument method.
apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)
94
93
0 commit comments