Skip to content

Commit 5c0045a

Browse files
committed
remove leaf.frozen field
1 parent e17e474 commit 5c0045a

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/adjust.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
4747
adjust(::Nothing, ::Real) = nothing
4848
adjust(::Nothing; kw...) = nothing
4949

50-
adjust(ℓ::Leaf, eta::Real) = .frozen ?: Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
51-
adjust(ℓ::Leaf; kw...) = .frozen ?: Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
50+
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
51+
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
5252

5353

5454
"""

src/interface.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ abstract type AbstractRule end
1313
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
1414
rule::R
1515
state::S
16-
frozen::Bool # mutability also allows this flag to be changed
1716
end
1817

1918
@functor Leaf
@@ -35,7 +34,7 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long
3534
ioc = IOContext(io, :compact => true)
3635
print(ioc, "Leaf(", ℓ.rule, ", ")
3736
show(ioc, ℓ.state)
38-
print(ioc, ", ", ℓ.frozen, ")")
37+
print(ioc, ")")
3938
end
4039

4140
###
@@ -49,7 +48,6 @@ function update!(tree, model, grad)
4948
grads!(dict, tree, model, grad)
5049
# Second walk is to update the model. The walk taken follows Leaf identity
5150
newmodel = fmap(tree, model; exclude =->isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
52-
.frozen && return x
5351
haskey(dict, ℓ) || return x # no gradient seen, nothing to do
5452
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
5553
.state = s′ # to get state out of here, rely on mutability of Leaf

0 commit comments

Comments
 (0)