@@ -48,7 +48,7 @@ function update!(tree, model, grad)
48
48
dict = IdDict {Leaf, Any} ()
49
49
grads! (dict, tree, model, grad)
50
50
# Second walk is to update the model. The walk taken follows Leaf identity
51
- newmodel = fmap (tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x
51
+ newmodel = fmap (tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache () ) do ℓ, x
52
52
ℓ. frozen && return x
53
53
haskey (dict, ℓ) || return x # no gradient seen, nothing to do
54
54
s′, x̄′ = apply! (ℓ. rule, ℓ. state, x, dict[ℓ])
@@ -88,6 +88,21 @@ function _second_walk(f, x, y)
88
88
re (map (f, x′, y′))
89
89
end
90
90
91
"""
    LeafCache()

Cache handed to `fmap` (as its `cache` keyword) when `update!` rebuilds the model.

It wraps an `IdDict{Leaf,Any}`. Entries are stored only for `Leaf` keys; a write with any
other key is ignored, so results are never cached against trivial nodes such as `()` in
the state tree.
"""
struct LeafCache <: AbstractDict{Leaf,Any}
    dict::IdDict{Leaf,Any}
end

# Only the handful of dictionary methods needed by Functors are implemented for this type;
# possibly Functors should be upgraded so that this just works.
function LeafCache()
    return LeafCache(IdDict{Leaf,Any}())
end
97
+
98
# A `Leaf` key is stored: the write is forwarded to the wrapped `IdDict`, and whatever
# that call returns is returned here.
function Base.setindex!(c::LeafCache, x, ℓ::Leaf)
    return setindex!(c.dict, x, ℓ)
end

# Any other kind of key is deliberately not cached: nothing is stored, `nothing` is returned.
function Base.setindex!(c::LeafCache, x, _)
    return nothing
end
100
# Membership test, delegated unchanged to the wrapped `IdDict`.
function Base.in(k, c::LeafCache)
    return in(k, c.dict)
end
101
# Key lookup accepts any `k` and asks the wrapped `IdDict` whether it holds that key.
function Base.haskey(c::LeafCache, k)
    return haskey(c.dict, k)
end
102
# Retrieve the value cached for the leaf `ℓ` from the wrapped `IdDict`.
function Base.getindex(c::LeafCache, ℓ::Leaf)
    return getindex(c.dict, ℓ)
end
103
# Iteration is delegated to the wrapped `IdDict`; the state `i` (starting at 0) is passed
# straight through to it.
function Base.iterate(c::LeafCache, i = 0)
    return iterate(c.dict, i)
end
104
# Number of entries currently held in the wrapped `IdDict`.
function Base.length(c::LeafCache)
    return length(c.dict)
end
105
+
91
106
# By default every rule is treated as first order: any extra gradient arguments `dxs...`
# are dropped and the call is forwarded to the four-argument `apply!` method.
function apply!(o, state, x, dx, dxs...)
    return apply!(o, state, x, dx)
end
93
108
0 commit comments