
Commit a25b681

Add freeze!/thaw! (#112)
* add freeze/thaw
* make a keyword frozen, and print it
* tweak, simplify recursion
* also block adjust
* add tests
* add docs
* decouple from adjust, tweak words
* tweak the doc example
1 parent 7d0c939 commit a25b681

File tree (5 files changed: +114 -7 lines)

- docs/src/api.md
- docs/src/index.md
- src/adjust.jl
- src/interface.jl
- test/runtests.jl
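In short, the new API looks like this. A minimal sketch adapted from the docstring example added in `src/adjust.jl` below (it assumes a build of Optimisers.jl with this patch applied):

```julia
using Optimisers

m = (x = ([1.0], 2.0), y = [3.0]);        # a small parameter tree
s = Optimisers.setup(Momentum(), m);      # the state tree mirrors the model

Optimisers.freeze!(s.x)                   # mark every Leaf under s.x as frozen

# Frozen leaves are skipped by update/update!; only y moves here:
s, m = Optimisers.update(s, m, (x = ([1.0], 1.0), y = [1.0]));

Optimisers.thaw!(s)                       # un-freeze the whole state tree again
```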

docs/src/api.md (2 additions, 0 deletions)

@@ -35,6 +35,8 @@ Optimisers.setup
 Optimisers.update
 Optimisers.update!
 Optimisers.adjust(::Any, ::Real)
+Optimisers.freeze!
+Optimisers.thaw!
 ```
 
 Calling `Functors.@functor` on your model's layer types by default causes

docs/src/index.md (28 additions, 1 deletion)

@@ -138,6 +138,33 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children
 st = Optimisers.setup(DecayDescent(0.1), Layer(3))
 ```
 
+## Frozen Parameters
+
+To temporarily prevent training from affecting some parameters,
+use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
+They work by mutating all `Leaf`s of the state tree, or part of it.
+
+```julia
+using Flux, Optimisers
+
+x = randn(Float32, 28, 28, 1, 1);
+net = @autosize (size(x)...,) Chain(
+    Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
+)
+opt = Optimisers.setup(Optimisers.Momentum(), net);
+
+net.layers[3] isa Dense  # now freeze this layer's parameters:
+Optimisers.freeze!(opt.layers[3])
+opt.layers[3].bias  # confirm: Leaf(Momentum(...), [0.0, 0.0], frozen = true)
+
+Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
+
+net.layers[3].bias  # still zero, and its momentum is too:
+
+Optimisers.thaw!(opt)
+opt.layers[3].bias  # Leaf(Momentum(...), [0.0, 0.0])
+```
+
 ## Tied Parameters
 
 If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.

@@ -159,7 +186,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent  # true
 This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
 It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
 from StaticArrays.jl.
-
+
 
 ## Obtaining a flat parameter vector
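One detail worth noting from this new docs section: since `freeze!` and `thaw!` simply recurse over `Leaf`s, they can also be aimed at a single `Leaf` rather than a whole layer. A small sketch continuing from the Flux example above (same `opt`; names carry over from that snippet, and exact printed output may differ):

```julia
Optimisers.freeze!(opt.layers[3].bias)   # freeze just this one Leaf
opt.layers[3].weight                     # unaffected, still trains
Optimisers.thaw!(opt.layers[3])          # un-freeze everything under this layer
```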

src/adjust.jl (58 additions, 2 deletions)

@@ -1,3 +1,59 @@
+###
+### freezing
+###
+
+"""
+    Optimisers.freeze!(tree)
+
+Temporarily alters the state `tree = setup(rule, model)` so that parameters
+will not be updated. Un-done by [`thaw!`](@ref Optimisers.thaw!).
+
+Can be applied to the state corresponding to only part of a model,
+for instance with `model::Chain`, to freeze `model.layers[1]` you
+should call `freeze!(tree.layers[1])`.
+
+# Example
+```jldoctest
+julia> m = (x = ([1.0], 2.0), y = [3.0]);
+
+julia> s = Optimisers.setup(Momentum(), m);
+
+julia> Optimisers.freeze!(s.x)
+
+julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake gradient
+
+julia> m
+(x = ([1.0], 2.0), y = [-0.14159258336972558])
+
+julia> s
+(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+
+julia> Optimisers.thaw!(s)
+
+julia> s.x
+(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
+```
+"""
+freeze!(tree) = foreach(freeze!, tree)
+freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
+
+"""
+    Optimisers.thaw!(tree)
+
+The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters,
+mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`.
+"""
+thaw!(tree) = foreach(thaw!, tree)
+thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
+
+freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`freeze!` must not be applied to a model, only to the state tree from `setup`"))
+thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`thaw!` must not be applied to a model, only to the state tree from `setup`"))
+
+###
+### adjust
+###
 
 """
     Optimisers.adjust(tree, η) -> tree

@@ -47,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
 adjust(::Nothing, ::Real) = nothing
 adjust(::Nothing; kw...) = nothing
 
-adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
-adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
+adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
 
 
 """

src/interface.jl (8 additions, 4 deletions)

@@ -10,10 +10,12 @@ abstract type AbstractRule end
 ### setup
 ###
 
-mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing
+mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing...
   rule::R
   state::S
+  frozen::Bool  # ... and to allow freeze! to act on this.
 end
+Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)
 
 @functor Leaf
 

@@ -42,11 +44,12 @@ function _setup(rule, x; cache)
   end
 end
 
-function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
+function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green)
   ioc = IOContext(io, :compact => true)
-  print(ioc, "Leaf(", ℓ.rule, ", ")
+  str = sprint(show, ℓ.rule; context = ioc)  # produces Adam{Float32}(0.001, ... not 0.001f0
+  printstyled(io, "Leaf(", str, ", "; color = colour)
   show(ioc, ℓ.state)
-  print(ioc, ")")
+  printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour)
 end
 
 ###

@@ -83,6 +86,7 @@ function _update!(tree, x; grads, params)
 end
 function _update!(ℓ::Leaf, x; grads, params)
   haskey(params, (ℓ,x)) && return params[(ℓ,x)]
+  ℓ.frozen && return x
   params[(ℓ,x)] = if haskey(grads, ℓ)
     ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
     subtract!(x, x̄′)
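A short sketch of how the new `frozen` field surfaces at the REPL (assumes this patch; the exact formatting and colours may differ):

```julia
using Optimisers

s = Optimisers.setup(Momentum(), [1.0, 2.0])   # a single Leaf, since the model is one array
Optimisers.freeze!(s)
s   # printed in cyan, roughly: Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0], frozen = true)

Optimisers.Leaf(Momentum(), [0.0, 0.0]; frozen = true)   # the new keyword constructor, directly
```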

test/runtests.jl (18 additions, 0 deletions)

@@ -221,6 +221,24 @@ end
   @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
 end
 
+@testset "freeze/thaw" begin
+  m = (x=[1.0, 2.0], y=([3.0, 4.0], sin));
+  st = Optimisers.setup(Descent(0.1), m);
+  Optimisers.freeze!(st.y)
+  st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
+  @test m.x ≈ [0.9, 1.0]
+  @test m.y[1] == [3, 4]
+
+  st = Optimisers.adjust(st, 0.2)
+  Optimisers.thaw!(st)
+  st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
+  @test m.y[1] ≈ [-17.0, -196.0]
+  @test m.x ≈ [0.7, -1.0]
+
+  @test_throws ArgumentError Optimisers.freeze!(m)
+  @test_throws ArgumentError Optimisers.thaw!(m)
+end
+
 @testset "forgotten gradient" begin
   x = [1.0, 2.0]
   sx = Optimisers.setup(Descent(), x)
