
Commit 674527e

Merge pull request #1875 from darsnack/load-structured
Add a structural `loadparams!`
2 parents 5f17f1c + 6b533b8 · commit 674527e

8 files changed: +269 −37 lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ been removed in favour of MLDatasets.jl.
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
 * The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
+* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875), which copies trainable + non-trainable parameters and performs more thorough structural checking.
 
 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5"
 CUDA = "3"
 ChainRulesCore = "1.12"
-Functors = "0.2.1"
+Functors = "0.2.8"
 MLUtils = "0.2"
 MacroTools = "0.5"
 NNlib = "0.8.2"

docs/src/saving.md

Lines changed: 21 additions & 21 deletions

@@ -2,7 +2,7 @@
 
 You may wish to save models so that they can be loaded and run in a later
 session. The easiest way to do this is via
-[BSON.jl](https://github.com/MikeInnes/BSON.jl).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl).
 
 Save a model:
 
@@ -36,7 +36,6 @@ Chain(
   Dense(5 => 2),                        # 12 parameters
   NNlib.softmax,
 )                   # Total: 4 arrays, 67 parameters, 524 bytes.
-
 ```
 
 Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -46,15 +45,17 @@ versions of Flux).
 
 !!! note
 
-    If a saved model's weights are stored on the GPU, the model will not load
+    If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
     to the CPU](gpu.md) with `cpu(model)` before saving it.
 
-## Saving Model Weights
+!!! warning
 
-In some cases it may be useful to save only the model parameters themselves, and
-rebuild the model architecture in your code. You can use `params(model)` to get
-model parameters.
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and even strongly discouraged.
+    Saving models this way will only store the trainable parameters which
+    will result in incorrect behavior for layers like `BatchNorm`.
 
 ```Julia
 julia> using Flux
@@ -64,28 +65,27 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
 
 julia> weights = Flux.params(model);
 
-julia> using BSON: @save
-
-julia> @save "mymodel.bson" weights
-```
-
-You can easily load parameters back into a model with `Flux.loadparams!`.
+Loading the model as shown above will return a new model with the stored parameters.
+But sometimes you already have a model, and you want to load stored parameters into it.
+This can be done as
 
 ```julia
-julia> using Flux
+using Flux: loadmodel!
+using BSON: @load
 
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
+# some predefined model
+model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
 
-julia> using BSON: @load
+# load one model into another
+model = loadmodel!(model, @load("mymodel.bson"))
+```
 
-julia> @load "mymodel.bson" weights
+This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
 
-julia> Flux.loadparams!(model, weights)
+```@docs
+Flux.loadmodel!
 ```
 
-The new `model` we created will now be identical to the one we saved parameters for.
-
 ## Checkpointing
 
 In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
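For context, a self-contained round trip following the updated docs might look like the sketch below. It is only illustrative: it uses `BSON.load`'s dictionary-returning form rather than the `@load` macro shown in the diff, and the names `model2` and `"mymodel.bson"` are placeholders.

```julia
using Flux, BSON
using Flux: loadmodel!

# Build and save a whole model (move it to the CPU first if it lives on the GPU).
model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
BSON.@save "mymodel.bson" model

# Later: rebuild the same architecture in code, then copy the saved
# parameters (trainable and non-trainable) into it structurally.
model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
model2 = loadmodel!(model2, BSON.load("mymodel.bson")[:model])
```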

src/Flux.jl

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,8 @@ include("layers/normalise.jl")
 include("layers/upsample.jl")
 include("layers/show.jl")
 
+include("loading.jl")
+
 include("outputsize.jl")
 
 include("data/Data.jl")

src/deprecations.jl

Lines changed: 10 additions & 0 deletions

@@ -48,6 +48,16 @@ function Diagonal(size::Tuple; kw...)
   Scale(size...; kw...)
 end
 
+# Deprecate this eventually once saving models w/o structure is no more
+function loadparams!(m, xs)
+  Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!)
+  for (p, x) in zip(params(m), xs)
+    size(p) == size(x) ||
+      error("Expected param size $(size(p)), got $(size(x))")
+    copyto!(p, x)
+  end
+end
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
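The shim keeps old `params`-based loading code working while pointing users at the structural API. A small sketch of the behavioral difference (the models `m1`/`m2` and their sizes are arbitrary examples, not part of the commit): the positional copy only touches trainable arrays, so non-trainable state such as `BatchNorm`'s running statistics is left behind.

```julia
using Flux

m1 = Chain(Dense(2 => 3), BatchNorm(3))
m2 = Chain(Dense(2 => 3), BatchNorm(3))
m2[2].μ .= 1f0   # pretend m2 carries trained running statistics

# Deprecated path: zips params(m1) with the given arrays and copies them
# positionally. BatchNorm's μ/σ² are not in params, so they stay untouched.
Flux.loadparams!(m1, Flux.params(m2))
m1[2].μ == m2[2].μ   # false

# Structural path: copies trainable and non-trainable state alike.
Flux.loadmodel!(m1, m2)
m1[2].μ == m2[2].μ   # true
```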

src/functor.jl

Lines changed: 0 additions & 8 deletions

@@ -85,14 +85,6 @@ function params(m...)
   return ps
 end
 
-function loadparams!(m, xs)
-  for (p, x) in zip(params(m), xs)
-    size(p) == size(x) ||
-      error("Expected param size $(size(p)), got $(size(x))")
-    copyto!(p, x)
-  end
-end
-
 struct FluxCUDAAdaptor end
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))

src/loading.jl

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+loadleaf!(dst, src, err) = dst
+loadleaf!(dst::AbstractArray, src, err) =
+  error("Tried to copy $src into an array destination; this is not allowed.")
+loadleaf!(dst, src::AbstractArray, err) =
+  error("Tried to copy an array to $dst; this is not allowed.")
+function loadleaf!(dst::AbstractArray, src::Bool, err)
+  if iszero(src)
+    dst .= src
+  else
+    error("Cannot copy boolean parameter == true to non-zero parameter.")
+  end
+  return dst
+end
+loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :
+  error("Cannot copy non-zero parameter to boolean parameter == true.")
+function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
+  (size(dst) == size(src)) || throw(err)
+  copyto!(dst, src)
+end
+
+_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
+  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
+_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
+  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
+_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
+  error("Encountered tied destination parameters with untied and mismatched sources.")
+_tie_check(dst, src) = true
+
+_bool_tie_check(dst, src) = true
+
+"""
+    loadmodel!(dst, src)
+
+Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
+
+Recursively walks `dst` and `src` together using [`Functors.children`](@ref),
+calling `copyto!` on parameter arrays or throwing an error when there is a mismatch.
+Non-array elements (such as activation functions) are not copied and need not match.
+Zero bias vectors and `bias=false` are considered equivalent
+(see extended help for more details).
+
+# Examples
+```julia
+julia> dst = Chain(Dense(Flux.ones32(2, 5), true, tanh), Dense(2 => 1; bias = [1f0]))
+Chain(
+  Dense(5 => 2, tanh),                  # 12 parameters
+  Dense(2 => 1),                        # 3 parameters
+)                   # Total: 4 arrays, 15 parameters, 316 bytes.
+
+julia> dst[1].weight ≈ ones(2, 5)  # by construction
+true
+
+julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));
+
+julia> Flux.loadmodel!(dst, src);
+
+julia> dst[1].weight ≈ ones(2, 5)  # values changed
+false
+
+julia> iszero(dst[2].bias)
+true
+```
+
+# Extended help
+
+Throws an error when:
+- `dst` and `src` do not share the same fields (at any level)
+- the sizes of leaf nodes are mismatched between `dst` and `src`
+- copying non-array values to/from an array parameter
+  (except inactive parameters described below)
+- `dst` is a "tied" parameter (i.e. refers to another parameter) and
+  loaded into multiple times with mismatched source values
+
+Inactive parameters can be encoded by using the boolean value `false` instead of an array.
+If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
+however, attempting to copy a non-zero array to an inactive parameter will throw an error.
+Likewise, copying a `src` value of `false` to any `dst` array is valid,
+but copying a `src` value of `true` will error.
+"""
+function loadmodel!(dst, src; cache = Base.IdSet())
+  ldsts, _ = functor(dst)
+  lsrcs, _ = functor(src)
+  (keys(ldsts) == keys(lsrcs)) ||
+    throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
+
+  err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
+  foreach(ldsts, lsrcs) do ldst, lsrc
+    if ldst in cache # we already loaded this parameter before
+      _tie_check(ldst, lsrc) && return ldst
+    elseif Functors.isleaf(ldst) # our first time loading this leaf
+      push!(cache, ldst)
+      loadleaf!(ldst, lsrc, err)
+    else # this isn't a leaf
+      loadmodel!(ldst, lsrc; cache = cache)
+    end
+  end
+
+  return dst
+end
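To make the inactive-parameter and size-checking rules above concrete, here is a brief sketch. The layer sizes and the names `dst`, `src`, and `bad` are arbitrary illustrations: a `src` bias stored as `false` is treated as an all-zero bias, while mismatched array sizes raise the `DimensionMismatch` built inside `loadmodel!`.

```julia
using Flux

dst = Chain(Dense(3 => 2))                # bias is a zero Vector{Float32}
src = Chain(Dense(3 => 2; bias = false))  # bias is stored as the Bool `false`

# `false` counts as an all-zero bias: the weight is copied over and
# dst's bias vector is filled with zeros via loadleaf!.
Flux.loadmodel!(dst, src)

# Same field names, but a 2×4 weight instead of 2×3:
bad = Chain(Dense(4 => 2))
Flux.loadmodel!(dst, bad)   # throws DimensionMismatch
```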
