
Commit e84b61b

one more dict to allow artificial ties

1 parent 640765a

2 files changed (+29 -6 lines)


src/interface.jl

Lines changed: 13 additions & 6 deletions
@@ -45,16 +45,23 @@ end
 function update!(tree, model, grad)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
-  dict = IdDict{Leaf, Any}()
-  grads!(dict, tree, model, grad)
-  # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once:
+  gdict = IdDict{Leaf, Any}()
+  grads!(gdict, tree, model, grad)
+  # Second walk is to update the model, using same fmap walk as setup:
+  xdict = IdDict{Leaf, Any}()  # (this exists to allow for shared ℓ without shared x)
   newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ
     ℓ isa Leaf || error("this state does not match the model, expected a Leaf here")
     ℓ.frozen && return x
-    haskey(dict, ℓ) || return x
-    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
+    haskey(gdict, ℓ) || return x  # no gradient seen, nothing to do
+    if haskey(xdict, ℓ)
+      # This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet.
+      x′ = xdict[ℓ]  # ... and is why xdict exists.
+      size(x′) == size(x) || error("the same Leaf belongs to arrays of size $(size(x)) and $(size(x′))")
+      return x′
+    end
+    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, gdict[ℓ])
     ℓ.state = s′  # to get state out of here, rely on mutability of Leaf
-    subtract!(x, x̄′)
+    xdict[ℓ] = subtract!(x, x̄′)
   end
   tree, newmodel  # note that tree is guaranteed to be updated
 end
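A note on why the two IdDicts above work: IdDict hashes keys by object identity (===), so any two places in the model whose state is the very same mutable Leaf land on the same key, letting grads! pool their gradients and letting xdict hand back the already-updated array. A minimal standalone sketch of the accumulation idea, with a hypothetical ToyLeaf and addgrad! standing in for the package's actual Leaf and grads! recursion:

mutable struct ToyLeaf    # stand-in for Optimisers.Leaf; mutable, so each instance has its own identity
    state
end

function addgrad!(gdict::IdDict, ℓ::ToyLeaf, g)
    # Two fields sharing one Leaf hit the same key, so their gradients add:
    gdict[ℓ] = haskey(gdict, ℓ) ? gdict[ℓ] .+ g : g
end

ℓ = ToyLeaf(nothing)
gdict = IdDict{ToyLeaf, Any}()
addgrad!(gdict, ℓ, [3, 3])    # gradient seen through field a
addgrad!(gdict, ℓ, [7, 7])    # gradient seen again through tied field b
@assert gdict[ℓ] == [10, 10]  # the tied parameter receives the summed gradient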

test/runtests.jl

Lines changed: 16 additions & 0 deletions
@@ -292,6 +292,22 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
     # Error: no constructors for type Any
     @test_broken s4, t4 = Optimisers.update(stri, tri, g4)
   end
+
+  @testset "artificial" begin
+    # Interpret shared Leaf as implying shared parameters, even if this did not arise from shared arrays.
+    # No API for setting this at the moment, but can construct one by hand:
+    model = (a = [1,2.0], b = [1, 2.0], c = [1, 2.0], d = [1, 2.0])
+    honest = Optimisers.setup(Momentum(0.1), model)
+    trick = (a = honest.a, b = honest.a, c = honest.c, d = honest.d)  # makes a & b shared
+
+    trick2, model2 = Optimisers.update(trick, model, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10]))
+    trick3, model3 = Optimisers.update(trick2, model2, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10]))
+
+    @test model3.a == model3.b == model3.d  # same as having the gradients added
+    @test !(model3.a ≈ model3.c)
+    @test trick3.a === trick3.b  # leaves remain shared
+    model3.a === model3.b  # in fact arrays end up shared, but this is not required
+  end
 end

 end
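For comparison with the hand-built "trick" state above, a rough usage sketch of the path setup already supports, where the tie comes from the arrays themselves: fmap caches by identity, so two fields holding the very same array receive one shared Leaf (numbers mirror the test; assumes using Optimisers as in the package's tests):

using Optimisers

shared = [1.0, 2.0]
model = (a = shared, b = shared, c = [1.0, 2.0])
tree = Optimisers.setup(Momentum(0.1), model)
@assert tree.a === tree.b  # one Leaf serves both tied fields

# The first walk sums the gradients seen through a and b, so the tied
# parameter takes the same step as c does under the pre-summed gradient:
tree2, model2 = Optimisers.update(tree, model, (a = [3, 3], b = [7, 7], c = [10, 10]))
@assert model2.a == model2.c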
