Commit a6cdfdd

Refactor loadmodel! to use a custom recursion instead of fmap. Add more tests.

1 parent b2a2664 commit a6cdfdd

File tree

3 files changed: +138 −83 lines


docs/src/saving.md
Lines changed: 0 additions & 14 deletions

@@ -85,23 +85,9 @@ This ensures that the model loaded from `"mymodel.bson"` matches the structure of
 
 ```@docs
 Flux.loadmodel!
-Flux.loadto!
-Flux.isloadleaf
 Flux.loadleaf!
 ```
 
-### Customizing `loadmodel!` for a custom layer
-
-By default, [`loadmodel!`](@ref) will recursively walk a nested model (like a `Chain`) using [`Functors.fmap`](@ref) until it encounters a loading *leaf node*. A leaf node is defined as any node for which [`Flux.isloadleaf`](@ref) returns `true`. For example, consider the model
-
-```julia
-model = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2)))
-```
-
-Here, the `Chain` and `Parallel` layers are not leaf nodes, but all the `Dense` layers are leaf nodes. This makes sense, because `Dense` layers are the ones with parameters that we need to copy. The default behavior for [`Flux.isloadleaf`](@ref) should work for most custom layers, but you can override this function for your type.
-
-Once a pair of leaf nodes is encountered, `loadmodel!` will call [`Flux.loadto!`](@ref) on them. By default, this just copies the parameters from one leaf node to the other, but you can customize the behavior by overriding `loadto!` for your pair of types.
-
 ## Checkpointing
 
 In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
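
With the `isloadleaf`/`loadto!` customization hooks removed, a custom layer now takes part in loading purely through the generic `Functors` recursion: registering the type with `@functor` is enough for `loadmodel!` to reach its arrays. A minimal sketch of this, using a hypothetical `MyLayer` that is not part of Flux or of this commit:

```julia
using Flux
using Flux: loadmodel!

# Hypothetical custom layer; @functor exposes its fields as children,
# which is all the new loadmodel! recursion needs in order to walk into it.
struct MyLayer{W, B}
  weight::W
  bias::B
end
Flux.@functor MyLayer

(l::MyLayer)(x) = l.weight * x .+ l.bias

dst = Chain(MyLayer(zeros(Float32, 2, 5), zeros(Float32, 2)), Dense(2 => 1))
src = Chain(MyLayer(ones(Float32, 2, 5), ones(Float32, 2)), Dense(2 => 1))

loadmodel!(dst, src)
dst[1].weight == src[1].weight   # true, with no loading code specific to MyLayer
```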

src/loading.jl
Lines changed: 41 additions & 66 deletions

@@ -1,104 +1,68 @@
 """
-    isloadleaf(x)
-
-Return `true` whenever `x` should be treated as a "leaf node"
-for the purposes of loading parameters.
-By default, `isloadleaf` returns `true` if [`Functors.isleaf`](@ref)
-is `true` for all [`Functors.children(x)`](@ref `Functors.children`).
-
-You can override this function for a specific type if needed.
-"""
-isloadleaf(x) = all(Functors.isleaf, Functors.children(x))
+    loadleaf!(dst, src, err)
 
+Copy `src` to `dst` or throw `err` when their sizes are mismatched.
+By default, use `copyto!` when `dst` and `src` are arrays.
+When only `dst` is an array, set every element to `src`.
+Otherwise, just return `dst`.
 """
-    loadleaf!(x, x̄, err)
-
-Copy `x̄` to `x` or throw `err` when their sizes are mismatched.
-By default, use `copyto!` when `x` and `x̄` are arrays.
-Otherwise, just return `x`.
-"""
-loadleaf!(x, x̄, err) = x
-function loadleaf!(x::AbstractArray, x̄, err)
-  x .= x̄
-  return x
-end
-function loadleaf!(x::AbstractArray, x̄::AbstractArray, err)
-  (size(x) == size(x̄)) || throw(err)
-  copyto!(x, x̄)
+loadleaf!(dst, src, err) = dst
+function loadleaf!(dst::AbstractArray, src, err)
+  dst .= src
+  return dst
 end
-
-"""
-    loadto!(m, m̄)
-
-Load a leaf node `m̄` into `m`.
-
-By default, call [`Flux.loadleaf!`](@ref) on each pair of children
-in `zip(Functors.children(m), Functors.children(m̄))`.
-"""
-function loadto!(m::T, m̄::S) where {T, S}
-  (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
-
-  ls, _ = functor(m)
-  l̄s, _ = functor(m̄)
-  (keys(ls) == keys(l̄s)) ||
-    throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match."))
-
-  err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")
-  foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s)
-
-  return m
+function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
+  (size(dst) == size(src)) || throw(err)
+  copyto!(dst, src)
 end
 
 """
-    loadmodel!(m, m̄)
+    loadmodel!(dst, src)
 
-Copy all the parameters (trainable and non-trainable) from `m̄` to `m`.
+Copy all the parameters (trainable and non-trainable) from `src` to `dst`.
 
-`loadmodel!` recursively walks `m` and `m̄` until it encounters
-a subfield, `x`, (i.e. layer) where `isloadleaf(x)` is true.
-The parameters of the matching subfield, `x̄`, are copied to `x`,
-throwing an error whenever:
-- `x` and `x̄` are not the same type (e.g. loading a `Conv` to a `Dense`)
-- `x` and `x̄` do not share the same fields
-- the parameter sizes are mismatched between `x` and `x̄`
+`loadmodel!` recursively walks the [`Functors.children`](@ref) of `dst` and `src`,
+calling `loadleaf!` on any pair of children where [`Functors.isleaf`](@ref) is true.
+It throws an error whenever:
+- `dst` and `src` do not share the same fields (at any level)
+- the sizes of leaf nodes are mismatched between `dst` and `src`
 
 ```julia
 julia> using Flux: loadmodel!
 
-julia> m = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1))
+julia> dst = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1))
 Chain(
   Dense(5 => 2),                        # 12 parameters
   Dense(2 => 1),                        # 3 parameters
 )                   # Total: 4 arrays, 15 parameters, 316 bytes.
 
-julia> m̄ = Chain(Dense(5 => 2), Dense(2 => 1));
+julia> src = Chain(Dense(5 => 2), Dense(2 => 1));
 
-julia> all(isone, m[1].weight)
+julia> all(isone, dst[1].weight)
 true
 
-julia> m = loadmodel!(m, m̄)
+julia> dst = loadmodel!(dst, src)
 Chain(
   Dense(5 => 2),                        # 12 parameters
   Dense(2 => 1),                        # 3 parameters
 )                   # Total: 4 arrays, 15 parameters, 316 bytes.
 
-julia> all(isone, m[1].weight)
+julia> all(isone, dst[1].weight)
 false
 
-julia> m[1].weight == m̄[1].weight
+julia> dst[1].weight == src[1].weight
 true
 
-julia> m[2].bias == m̄[2].bias
+julia> dst[2].bias == src[2].bias
 true
 ```
 
 See [`Flux.loadleaf!`](@ref) for more details on the copy behavior.
-See [`Flux.isloadleaf`](@ref) for more details on which layers are considered leaves.
 
 !!! warning
-    This function allows `m̄` to be a vector or `Params` for backwards-compatibility.
+    This function allows `src` to be a `Params` for backwards-compatibility.
     You should avoid using `loadmodel!` this way, because it skips most of the structural
-    checking used when `m̄` is also a struct. Silent errors may occur.
+    checking used when `src` is also a nested structure. Silent errors may occur.
 """
 function loadmodel!(m, xs::Params)
   for (p, x) in zip(params(m), xs)
@@ -107,5 +71,16 @@ function loadmodel!(m, xs::Params)
     copyto!(p, x)
   end
 end
-loadmodel!(m, xs::AbstractVector) = loadmodel!(m, params(xs))
-loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = isloadleaf)
+function loadmodel!(dst, src)
+  ldsts, _ = functor(dst)
+  lsrcs, _ = functor(src)
+  (keys(ldsts) == keys(lsrcs)) ||
+    throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
+
+  err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
+  foreach(ldsts, lsrcs) do ldst, lsrc
+    Functors.isleaf(ldst) ? loadleaf!(ldst, lsrc, err) : loadmodel!(ldst, lsrc)
+  end
+
+  return dst
+end
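
Taken together, the three `loadleaf!` methods give the copy semantics the docstring describes. A small illustrative sketch (the values are arbitrary, and `loadleaf!` is internal, hence the module prefix):

```julia
using Flux

err = DimensionMismatch("parameter sizes do not match")

Flux.loadleaf!([0.0, 0.0], [1.0, 2.0], err)  # array ← array: in-place copyto!
Flux.loadleaf!([0.0, 0.0], 3.0, err)         # array ← non-array: broadcast, giving [3.0, 3.0]
Flux.loadleaf!(relu, [1.0, 2.0], err)        # non-array destination: returned unchanged
Flux.loadleaf!([0.0], [1.0, 2.0], err)       # mismatched sizes: throws err
```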

test/utils.jl
Lines changed: 97 additions & 3 deletions

@@ -378,7 +378,7 @@ end
   end
 end
 
-@testset "loadmodel!(m, m̄)" begin
+@testset "loadmodel!(dst, src)" begin
   import Flux: loadmodel!, Zeros
 
   m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
@@ -389,17 +389,111 @@ end
   m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2)))
 
   loadmodel!(m1, m2)
+  # trainable parameters copy over
   @test m1[1].weight == m2[1].weight
   @test m1[1].bias == m2[1].bias
+  # non-array leaves are untouched
   @test m1[2].σ == relu
+
   loadmodel!(m5, m6)
+  # more complex nested structures also work
   @test m5[1].weight == m6[1].weight
   @test m5[2][1].weight == m6[2][1].weight
-  @test m5[2][1].bias == Zeros()
+  # false bias is not overwritten
+  @test m5[2][1].bias == false
 
+  # mismatched nodes throw an error
   @test_throws ArgumentError loadmodel!(m1, m3)
-  @test_throws DimensionMismatch loadmodel!(m1, m4)
   @test_throws ArgumentError loadmodel!(m1, m5)
+  # size mismatches throw an error
+  @test_throws DimensionMismatch loadmodel!(m1, m4)
+
+  m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2))
+  m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1))
+  m2[2].μ .= rand(Float32, size(m2[2].μ)...)
+  loadmodel!(m1, m2)
+  # non-trainable parameters are copied as well
+  @test m1[2].μ == m2[2].μ
+  # functions are not copied
+  @test m1[3] == Flux.flatten
+  # dropout rate is not copied
+  @test m1[4].p == 0.2
+
+  # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33)
+  # tests Chain(...) vs Chain([...])
+  # tests MaxPool
+  # tests testmode!/trainmode! is not copied
+  # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model
+  chain1 = Chain(Dropout(0.2),
+                 Conv((3, 3), 1 => 32, relu),
+                 BatchNorm(32, relu),
+                 MaxPool((2, 2)),
+                 Dropout(0.2),
+                 Conv((3, 3), 32 => 16, relu),
+                 Dropout(0.2),
+                 MaxPool((2, 2)),
+                 Dropout(0.2),
+                 Conv((3, 3), 16 => 10, relu),
+                 Dropout(0.2),
+                 x -> reshape(x, :, size(x, 4)),
+                 Dropout(0.2),
+                 Dense(90, 10),
+                 softmax)
+  chain2 = Chain([Dropout(0.1),
+                  Conv((3, 3), 1 => 32, relu),
+                  BatchNorm(32, relu),
+                  MaxPool((3, 3)),
+                  Dropout(0.1),
+                  Conv((3, 3), 32 => 16, relu),
+                  Dropout(0.1),
+                  MaxPool((3, 3)),
+                  Dropout(0.1),
+                  Conv((3, 3), 16 => 10, relu),
+                  Dropout(0.1),
+                  x -> reshape(x, :, size(x, 4)),
+                  Dropout(0.1),
+                  Dense(90, 10),
+                  softmax])
+  chain2[3].μ .= 5f0
+  chain2[3].σ² .= 2f0
+  testmode!(chain2)
+  loadmodel!(chain1, chain2)
+  for (dst, src) in zip(chain1, chain2)
+    if dst isa Dropout
+      @test dst.p == 0.2
+    elseif dst isa Union{Conv, Dense}
+      @test dst.weight == src.weight
+      @test dst.bias == src.bias
+    elseif dst isa MaxPool
+      @test dst.k == (2, 2)
+    elseif dst isa BatchNorm
+      @test dst.μ == src.μ
+      @test dst.σ² == src.σ²
+      @test isnothing(dst.active)
+    end
+  end
+
+  # copy only a subset of the model
+  chain1[end - 1].weight .= 1f0
+  chain1[3].μ .= 3f0
+  chain1[2].bias .= 5f0
+  loadmodel!(chain2[end - 1], chain1[end - 1])
+  loadmodel!(chain2[3], chain1[3])
+  @test chain2[end - 1].weight == chain1[end - 1].weight
+  @test chain2[3].μ == chain1[3].μ
+  @test chain2[2].bias != chain1[2].bias
+
+  # test shared weights
+  m1 = Chain(Dense(10 => 5), Dense(5 => 2))
+  m2 = Chain(Dense(transpose(m1[2].weight)), Dense(permutedims(m1[1].weight)))
+  m3 = Chain(Dense(m1[1].weight), Dense(m1[2].weight))
+  m2[2].weight .= 1f0
+  loadmodel!(m1, m3)
+  @test m1[2].weight === parent(m2[1].weight)
+  @test m1[2].weight == transpose(m2[1].weight)
+  @test m1[1].weight === m3[1].weight
+  @test m2[2].weight != transpose(m1[1].weight)
+  @test m3[2].weight == transpose(m2[1].weight)
 end
 
 @testset "destructure" begin
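
One behavior the new shared-weights tests pin down: because `loadmodel!` copies with in-place `copyto!`, parameters that alias the same array before loading still alias it afterwards. A minimal sketch of the consequence (hypothetical models, not taken from the test suite):

```julia
using Flux
using Flux: loadmodel!

W = Flux.ones32(2, 5)
dst = Chain(Dense(W), Dense(transpose(W)))  # the second layer reuses the same array
src = Chain(Dense(5 => 2), Dense(2 => 5))

loadmodel!(dst, src)
dst[1].weight === parent(dst[2].weight)     # true: aliasing is preserved
dst[1].weight == transpose(src[2].weight)   # true: each alias is copied in turn,
                                            # so the last copy is the one that remains
```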
