Commit b2a2664

Better docs
1 parent fbc9faf commit b2a2664

File tree

2 files changed: +56 -7 lines changed


docs/src/saving.md

Lines changed: 13 additions & 0 deletions

````diff
@@ -85,10 +85,23 @@ This ensures that the model loaded from `"mymodel.bson"` matches the structure o
 
 ```@docs
 Flux.loadmodel!
+Flux.loadto!
 Flux.isloadleaf
 Flux.loadleaf!
 ```
 
+### Customizing `loadmodel!` for a custom layer
+
+By default, [`loadmodel!`](@ref) will recursively walk a nested model (like a `Chain`) using [`Functors.fmap`](@ref) until it encounters a loading *leaf node*. A leaf node is any node for which [`Flux.isloadleaf`](@ref) returns `true`. For example, consider the model
+
+```julia
+model = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2)))
+```
+
+Here, the `Chain` and `Parallel` layers are not leaf nodes, but all the `Dense` layers are. This makes sense, because the `Dense` layers are the ones with parameters to copy. The default behavior of [`Flux.isloadleaf`](@ref) should work for most custom layers, but you can override this function for your type.
+
+Once a pair of leaf nodes is encountered, `loadmodel!` calls [`Flux.loadto!`](@ref) on them. By default, this just copies the parameters from one leaf node to the other, but you can customize the behavior by overriding `loadto!` for your pair of types.
+
 ## Checkpointing
 
 In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
````
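The customization hooks described in the added documentation can be sketched as follows. Here `ScaleLayer` is a hypothetical layer (not part of Flux), and the sketch assumes the `Flux.isloadleaf`/`Flux.loadto!` API introduced by this commit:

```julia
using Flux, Functors

# Hypothetical layer (not part of Flux): a Dense layer plus an extra
# per-output scale that we want treated as one unit when loading.
struct ScaleLayer{D, S}
    dense::D
    scale::S
end
Functors.@functor ScaleLayer

(l::ScaleLayer)(x) = l.scale .* l.dense(x)

# Stop `loadmodel!`'s recursion at the whole layer, so `Flux.loadto!`
# receives the `ScaleLayer` pair rather than its children.
Flux.isloadleaf(::ScaleLayer) = true

# Customize how one `ScaleLayer` is loaded into another: delegate the
# inner Dense to `loadmodel!` and copy the scale array directly.
function Flux.loadto!(dst::ScaleLayer, src::ScaleLayer)
    Flux.loadmodel!(dst.dense, src.dense)
    copyto!(dst.scale, src.scale)
    return dst
end
```

This is only a sketch of the extension points; the default `isloadleaf`/`loadto!` behavior already handles ordinary layers whose children are parameter arrays.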

src/loading.jl

Lines changed: 43 additions & 7 deletions

````diff
@@ -27,7 +27,17 @@ function loadleaf!(x::AbstractArray, x̄::AbstractArray, err)
   copyto!(x, x̄)
 end
 
-function _loadto!(m, m̄)
+"""
+    loadto!(m, m̄)
+
+Load a leaf node `m̄` into `m`.
+
+By default, call [`Flux.loadleaf!`](@ref) on each pair of children
+in `zip(Functors.children(m), Functors.children(m̄))`.
+"""
+function loadto!(m::T, m̄::S) where {T, S}
+  (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
   ls, _ = functor(m)
   l̄s, _ = functor(m̄)
   (keys(ls) == keys(l̄s)) ||
@@ -38,10 +48,6 @@ function _loadto!(m, m̄)
 
   return m
 end
-function loadto!(m::T, m̄::S) where {T, S}
-  (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
-  _loadto!(m, m̄)
-end
 
 """
     loadmodel!(m, m̄)
@@ -56,8 +62,38 @@ throwing an error whenever:
 - `x` and `x̄` do not share the same fields
 - the parameter sizes are mismatched between `x` and `x̄`
 
-See [`loadleaf!`](@ref) for more details on the copy behavior.
-See [`isloadleaf`](@ref) for more details on which layers are considered leaves.
+```julia
+julia> using Flux: loadmodel!
+
+julia> m = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1))
+Chain(
+  Dense(5 => 2),                        # 12 parameters
+  Dense(2 => 1),                        # 3 parameters
+)                   # Total: 4 arrays, 15 parameters, 316 bytes.
+
+julia> m̄ = Chain(Dense(5 => 2), Dense(2 => 1));
+
+julia> all(isone, m[1].weight)
+true
+
+julia> m = loadmodel!(m, m̄)
+Chain(
+  Dense(5 => 2),                        # 12 parameters
+  Dense(2 => 1),                        # 3 parameters
+)                   # Total: 4 arrays, 15 parameters, 316 bytes.
+
+julia> all(isone, m[1].weight)
+false
+
+julia> m[1].weight == m̄[1].weight
+true
+
+julia> m[2].bias == m̄[2].bias
+true
+```
+
+See [`Flux.loadleaf!`](@ref) for more details on the copy behavior.
+See [`Flux.isloadleaf`](@ref) for more details on which layers are considered leaves.
 
 !!! warning
     This function allows `m̄` to be a vector or `Params` for backwards-compatibility.
````
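As a usage sketch of the behavior documented in this diff (assuming a Flux build containing this commit's `loadmodel!`), copying parameters between two structurally matching models looks like:

```julia
using Flux

# Two models with identical structure but independently initialized parameters.
src = Chain(Dense(4 => 3, relu), Dense(3 => 2))
dst = Chain(Dense(4 => 3, relu), Dense(3 => 2))

# Copy every parameter array from `src` into `dst`, in place.
Flux.loadmodel!(dst, src)

# `dst` now holds equal values to `src`, but does not alias its arrays.
@assert dst[1].weight == src[1].weight
@assert dst[1].weight !== src[1].weight
```

Loading a layer into one of a different type (or with mismatched parameter sizes) throws an `ArgumentError` via the type check added to `loadto!`.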

0 commit comments
