Skip to content

Commit f046185

Browse files
committed
remove leaf.frozen field
1 parent e17e474 commit f046185

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

src/adjust.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
4747
adjust(::Nothing, ::Real) = nothing
4848
adjust(::Nothing; kw...) = nothing
4949

50-
adjust(ℓ::Leaf, eta::Real) = .frozen ?: Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
51-
adjust(ℓ::Leaf; kw...) = .frozen ?: Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
50+
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
51+
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
5252

5353

5454
"""

src/interface.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ abstract type AbstractRule end
1313
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
1414
rule::R
1515
state::S
16-
frozen::Bool # mutability also allows this flag to be changed
1716
end
1817

1918
@functor Leaf
@@ -25,7 +24,7 @@ function setup(rule::AbstractRule, model)
2524
# Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
2625
tree = fmapstructure(model, exclude = isnumeric) do x
2726
cnt[] += 1
28-
Leaf(rule, init(rule, x), false)
27+
Leaf(rule, init(rule, x))
2928
end
3029
cnt[] == 0 && @warn "setup found no parameters in the given model"
3130
tree
@@ -35,7 +34,7 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long
3534
ioc = IOContext(io, :compact => true)
3635
print(ioc, "Leaf(", ℓ.rule, ", ")
3736
show(ioc, ℓ.state)
38-
print(ioc, ", ", ℓ.frozen, ")")
37+
print(ioc, ")")
3938
end
4039

4140
###
@@ -49,7 +48,6 @@ function update!(tree, model, grad)
4948
grads!(dict, tree, model, grad)
5049
# Second walk is to update the model. The walk taken follows Leaf identity
5150
newmodel = fmap(tree, model; exclude =->isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
52-
.frozen && return x
5351
haskey(dict, ℓ) || return x # no gradient seen, nothing to do
5452
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
5553
.state = s′ # to get state out of here, rely on mutability of Leaf

0 commit comments

Comments
 (0)