Skip to content

Commit e17e474

Browse files
committed
add a LeafCache type, to make fmap ignore () singleton
1 parent 0de29e1 commit e17e474

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/interface.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function update!(tree, model, grad)
4848
dict = IdDict{Leaf, Any}()
4949
grads!(dict, tree, model, grad)
5050
# Second walk is to update the model. The walk taken follows Leaf identity
51-
newmodel = fmap(tree, model; exclude =->isa Leaf, walk = _second_walk) do ℓ, x
51+
newmodel = fmap(tree, model; exclude =->isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
5252
.frozen && return x
5353
haskey(dict, ℓ) || return x # no gradient seen, nothing to do
5454
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
@@ -88,6 +88,21 @@ function _second_walk(f, x, y)
8888
re(map(f, x′, y′))
8989
end
9090

91+
# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state.
92+
# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work.
93+
struct LeafCache <: AbstractDict{Leaf,Any}
94+
dict::IdDict{Leaf,Any}
95+
end
96+
LeafCache() = LeafCache(IdDict{Leaf,Any}())
97+
98+
Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ)
99+
Base.setindex!(c::LeafCache, x, _) = nothing
100+
Base.in(k, c::LeafCache) = k in c.dict
101+
Base.haskey(c::LeafCache, k) = haskey(c.dict, k)
102+
Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ)
103+
Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i)
104+
Base.length(c::LeafCache) = length(c.dict)
105+
91106
# default all rules to first order calls
92107
apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
93108

0 commit comments

Comments
 (0)