 function update!(tree, model, grad)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
-  dict = IdDict{Leaf, Any}()
-  grads!(dict, tree, model, grad)
-  # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once:
+  gdict = IdDict{Leaf, Any}()
+  grads!(gdict, tree, model, grad)
+  # Second walk is to update the model, using same fmap walk as setup:
+  xdict = IdDict{Leaf, Any}()  # (this exists to allow for shared ℓ without shared x)
   newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ
     ℓ isa Leaf || error("this state does not match the model, expected a Leaf here")
     ℓ.frozen && return x
-    haskey(dict, ℓ) || return x
-    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
+    haskey(gdict, ℓ) || return x  # no gradient seen, nothing to do
+    if haskey(xdict, ℓ)
+      # This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet.
+      x′ = xdict[ℓ]  # ... and is why xdict exists.
+      size(x′) == size(x) || error("the same Leaf belongs to arrays of size $(size(x)) and $(size(x′))")
+      return x′
+    end
+    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, gdict[ℓ])
     ℓ.state = s′  # to get state out of here, rely on mutability of Leaf
-    subtract!(x, x̄′)
+    xdict[ℓ] = subtract!(x, x̄′)
   end
   tree, newmodel  # note that tree is guaranteed to be updated
 end
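For context, a minimal usage sketch of the shared-parameter case this change handles. It assumes only Optimisers.jl's public `Descent`, `setup` and `update!`; the model and gradient below are invented for illustration. Because the same array appears in two branches, `setup` creates a single shared `Leaf` for it, the first walk accumulates both gradient contributions onto that one entry, and the second walk applies the rule once and writes the same updated array back to both places.

using Optimisers

w = rand(Float32, 3, 3)                       # one array, tied into two branches
model = (enc = w, dec = w, bias = zeros(Float32, 3))

tree = Optimisers.setup(Descent(0.1f0), model)
tree.enc === tree.dec                         # true: one shared Leaf for the tied array

# Hand-written gradient matching the model's structure; both tied copies carry
# a contribution, which the first walk sums onto the single shared Leaf:
grad = (enc = ones(Float32, 3, 3), dec = ones(Float32, 3, 3), bias = nothing)

tree, model = Optimisers.update!(tree, model, grad)
model.enc === model.dec                       # still one array after the update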