Skip to content

Commit adb956e

Browse files
committed
tweak, simplify recursion
1 parent 78f2dae commit adb956e

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/adjust.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
"""
66
Optimisers.freeze!(tree)
77
8-
Temporarily alters the state `tree = setup(rule, model)` so that parameters will not be updated.
9-
Can be applied to the state corresponding to only part of a model, for instance `model.layers[1]`.
10-
Un-done by [`thaw!`](@ref Optimisers.thaw).
8+
Temporarily alters the state `tree = setup(rule, model)` so that parameters
9+
will not be updated. Un-done by [`thaw!`](@ref Optimisers.thaw!).
10+
11+
Can be applied to the state corresponding to only part of a model,
12+
for instance with `model::Chain`, to freeze `model.layers[1]` you
13+
should call `freeze!(tree.layers[1])`.
1114
1215
# Example
1316
```jldoctest
@@ -31,16 +34,16 @@ julia> s.x
3134
(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
3235
```
3336
"""
34-
freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing)
37+
freeze!(tree) = foreach(freeze!, tree)
3538
freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
3639

3740
"""
3841
Optimisers.thaw!(tree)
3942
40-
Un-does [`freeze!`](@ref Optimisers.freeze!) for all parameters,
41-
mutating every `Leaf(rule, state, true)` to `Leaf(rule, state, false)`.
43+
The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters,
44+
mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`.
4245
"""
43-
thaw!(tree) = (fmapstructure(thaw!, tree; exclude = x -> x isa Leaf); nothing)
46+
thaw!(tree) = foreach(thaw!, tree)
4447
thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
4548

4649
freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green)
4949
str = sprint(show, ℓ.rule; context = ioc)
5050
printstyled(io, "Leaf(", str, ", "; color = colour)
5151
show(ioc, ℓ.state)
52-
printstyled(io, ℓ.frozen ? ", frozen=true)" : ")"; color = colour)
52+
printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour)
5353
end
5454

5555
###

0 commit comments

Comments
 (0)