Skip to content

Commit 670e49a

Browse files
committed
a tidier idea, just replace _default_walk
1 parent 64d5d9f commit 670e49a

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

src/interface.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ abstract type AbstractRule end
1010
### setup
1111
###
1212

13-
mutable struct Leaf{R,S}
13+
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
1414
rule::R
1515
state::S
16-
frozen::Bool
16+
frozen::Bool # mutability also allows this flag to be changed
1717
end
1818

1919
@functor Leaf
@@ -45,23 +45,15 @@ end
4545
function update!(tree, model, grad)
4646
# First walk is to accumulate the gradient. This recursion visits every copy of
4747
# shared leaves, but stops when branches are absent from the gradient:
48-
gdict = IdDict{Leaf, Any}()
49-
grads!(gdict, tree, model, grad)
50-
# Second walk is to update the model, using same fmap walk as setup:
51-
xdict = IdDict{Leaf, Any}() # (this exists to allow for shared ℓ without shared x)
52-
newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ
53-
ℓ isa Leaf || error("this state does not match the model, expected a Leaf here")
48+
dict = IdDict{Leaf, Any}()
49+
grads!(dict, tree, model, grad)
50+
# Second walk is to update the model. The walk taken follows Leaf identity
51+
newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x
5452
ℓ.frozen && return x
55-
haskey(gdict, ℓ) || return x # no gradient seen, nothing to do
56-
if haskey(xdict, ℓ)
57-
# This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet.
58-
x′ = xdict[ℓ] # ... and is why xdict exists.
59-
size(x′) == size(x) || error("the same Leaf belongs to arrays of size $(size(x)) and $(size(x′))")
60-
return x′
61-
end
62-
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, gdict[ℓ])
53+
haskey(dict, ℓ) || return x # no gradient seen, nothing to do
54+
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
6355
ℓ.state = s′ # to get state out of here, rely on mutability of Leaf
64-
xdict[ℓ] = subtract!(x, x̄′)
56+
subtract!(x, x̄′)
6557
end
6658
tree, newmodel # note that tree is guaranteed to be updated
6759
end
@@ -89,6 +81,13 @@ function update(tree, x, x̄s...)
8981
update!(t′, x′, x̄s...)
9082
end
9183

84+
# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first
85+
function _second_walk(f, x, y)
86+
x′, _ = functor(typeof(y), x)
87+
y′, re = functor(y)
88+
re(map(f, x′, y′))
89+
end
90+
9291
# default all rules to first order calls
9392
apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)
9493

0 commit comments

Comments
 (0)