Skip to content

Commit 29662b2

Browse files
committed
Add better support for loadmodel! w/ tied parameters and address some other review comments
1 parent a6cdfdd commit 29662b2

File tree

3 files changed

+78
-45
lines changed

3 files changed

+78
-45
lines changed

src/deprecations.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
1515
ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
1616
zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
1717

18-
@deprecate loadparams!(m, xs) loadmodel!(m, xs)
19-
2018
# v0.13 deprecations
2119

2220
function Broadcast.broadcasted(f::Recur, args...)
@@ -50,6 +48,16 @@ function Diagonal(size::Tuple; kw...)
5048
Scale(size...; kw...)
5149
end
5250

51+
# Deprecate this eventually once saving models w/o structure is no more
52+
function loadparams!(m, xs)
53+
Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!)
54+
for (p, x) in zip(params(m), xs)
55+
size(p) == size(x) ||
56+
error("Expected param size $(size(p)), got $(size(x))")
57+
copyto!(p, x)
58+
end
59+
end
60+
5361
# Channel notation: Changed to match Conv, but very softly deprecated!
5462
# Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
5563
Dense(in::Integer, out::Integer, σ = identity; kw...) =

src/loading.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,42 @@
1-
"""
2-
loadleaf!(dst, src, err)
3-
4-
Copy `src` to `dst` or throw `err` when their sizes are mismatched.
5-
By default, use `copyto!` when `dst` and `src` are arrays.
6-
When only `dst` is an array, set every element to `src`.
7-
Otherwise, just return `dst`.
8-
"""
91
loadleaf!(dst, src, err) = dst
10-
function loadleaf!(dst::AbstractArray, src, err)
11-
dst .= src
2+
function loadleaf!(dst::AbstractArray, src::Bool, err)
3+
if iszero(src)
4+
dst .= src
5+
else
6+
error("Cannot copy boolean parameter == true to non-zero parameter.")
7+
end
128
return dst
139
end
10+
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :
11+
error("Cannot copy non-zero parameter to boolean parameter == true.")
1412
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
1513
(size(dst) == size(src)) || throw(err)
1614
copyto!(dst, src)
1715
end
1816

17+
_parent(x) = x
18+
_parent(x::AbstractArray) = parent(x)
19+
20+
_tie_check(dst::AbstractArray, src::AbstractArray) = dst == src
21+
_tie_check(dst, src) = true
22+
23+
_bool_tie_check(dst::Bool, src::AbstractArray) = iszero(dst)
24+
_bool_tie_check(dst::AbstractArray, src::Bool) = iszero(dst) && iszero(src)
25+
_bool_tie_check(dst, src) = true
26+
1927
"""
2028
loadmodel!(dst, src)
2129
2230
Copy all the parameters (trainable and non-trainable) from `src` to `dst`.
2331
2432
`loadmodel!` recursively walks the [`Functors.children`](@ref) of `dst` and `src`
25-
calling `loadleaf!` on any pair of children where [`Functors.isleaf`](@ref) is true.
33+
calling `copyto!` on any pair of children where [`Functors.isleaf`](@ref) is true.
34+
It also handles "absent" parameters such as `bias == false`.
2635
It throws an error whenever:
2736
- `dst` and `src` do not share the same fields (at any level)
2837
- the sizes of leaf nodes are mismatched between `dst` and `src`
38+
- `dst` is a "tied" parameter (e.g. `transpose` of another parameter) and
39+
loaded into multiple times with mismatched source values
2940
3041
```julia
3142
julia> using Flux: loadmodel!
@@ -56,30 +67,29 @@ true
5667
julia> dst[2].bias == src[2].bias
5768
true
5869
```
59-
60-
See [`Flux.loadleaf!`](@ref) for more details on the copy behavior.
61-
62-
!!! warning
63-
This function allows `src` to be a `Params` for backwards-compatibility.
64-
You should avoid using `loadmodel!` this way, because it skips most of the structural
65-
checking used when `src` is also a nested structure. Silent errors may occur.
6670
"""
67-
function loadmodel!(m, xs::Params)
68-
for (p, x) in zip(params(m), xs)
69-
size(p) == size(x) ||
70-
error("Expected param size $(size(p)), got $(size(x))")
71-
copyto!(p, x)
72-
end
73-
end
74-
function loadmodel!(dst, src)
71+
function loadmodel!(dst, src; cache = Base.IdSet())
7572
ldsts, _ = functor(dst)
7673
lsrcs, _ = functor(src)
7774
(keys(ldsts) == keys(lsrcs)) ||
7875
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
7976

8077
err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
8178
foreach(ldsts, lsrcs) do ldst, lsrc
82-
Functors.isleaf(ldst) ? loadleaf!(ldst, lsrc, err) : loadmodel!(ldst, lsrc)
79+
if _parent(ldst) in cache # we already loaded this parameter before
80+
if !_bool_tie_check(ldst, lsrc) # special case to handle tied + absent parameters
81+
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
82+
elseif _tie_check(ldst, lsrc) # the arrays match and we already loaded (or these are not arrays)
83+
return ldst
84+
else # tied dst but mismatched src case
85+
error("Encountered tied destination parameters with untied and mismatched sources.")
86+
end
87+
elseif Functors.isleaf(ldst) # our first time loading this leaf
88+
push!(cache, ldst)
89+
loadleaf!(ldst, lsrc, err)
90+
else # this isn't a leaf
91+
loadmodel!(ldst, lsrc; cache = cache)
92+
end
8393
end
8494

8595
return dst

test/utils.jl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Flux
22
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
33
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
44
sparse_init, identity_init, stack, unstack, batch, unbatch,
5-
unsqueeze, params, loadmodel!
5+
unsqueeze, params, loadparams!, loadmodel!
66
using StatsBase: var, std
77
using Statistics, LinearAlgebra
88
using Random
@@ -366,26 +366,24 @@ end
366366
@test_skip typeof(l1.bias) === typeof(l2.bias)
367367
end
368368

369-
@testset "loadmodel!" begin
369+
@testset "loadparams!" begin
370370
pars(w, b) = [w, b]
371371
pars(l) = pars(l.weight, l.bias)
372372
pararray(m) = mapreduce(pars, vcat, m)
373373
weights(m) = mapreduce(l -> [l.weight], vcat, m)
374374
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
375375
m = dm(bt)
376-
Flux.loadmodel!(m, params(m))
376+
Flux.loadparams!(m, params(m))
377377
testdense(m, bt)
378378
end
379379
end
380380

381381
@testset "loadmodel!(dst, src)" begin
382-
import Flux: loadmodel!, Zeros
383-
384382
m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
385383
m2 = Chain(Dense(10, 5), Dense(5, 2))
386384
m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2))
387385
m4 = Chain(Dense(10, 6), Dense(6, 2))
388-
m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), Zeros()), Dense(5, 2)))
386+
m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2)))
389387
m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2)))
390388

391389
loadmodel!(m1, m2)
@@ -408,6 +406,7 @@ end
408406
# size mismatches throw an error
409407
@test_throws DimensionMismatch loadmodel!(m1, m4)
410408

409+
# tests for BatchNorm and Dropout
411410
m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2))
412411
m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1))
413412
m2[2].μ .= rand(Float32, size(m2[2].μ)...)
@@ -484,16 +483,32 @@ end
484483
@test chain2[2].bias != chain1[2].bias
485484

486485
# test shared weights
487-
m1 = Chain(Dense(10 => 5), Dense(5 => 2))
488-
m2 = Chain(Dense(transpose(m1[2].weight)), Dense(permutedims(m1[1].weight)))
489-
m3 = Chain(Dense(m1[1].weight), Dense(m1[2].weight))
490-
m2[2].weight .= 1f0
491-
loadmodel!(m1, m3)
492-
@test m1[2].weight === parent(m2[1].weight)
493-
@test m1[2].weight == transpose(m2[1].weight)
494-
@test m1[1].weight === m3[1].weight
495-
@test m2[2].weight != transpose(m1[1].weight)
496-
@test m3[2].weight == transpose(m2[1].weight)
486+
encoder_dst = Chain(Dense(10 => 5), Dense(5 => 2))
487+
decoder_dst = Chain(Dense(transpose(encoder_dst[2].weight)),
488+
Dense(permutedims(encoder_dst[1].weight)))
489+
encoder_src = Chain(Dense(10 => 5), Dense(5 => 2))
490+
decoder_src = Chain(Dense(transpose(encoder_src[2].weight)),
491+
Dense(5 => 10))
492+
# matched weights are okay
493+
m1 = Chain(encoder_dst, decoder_dst)
494+
m2 = Chain(encoder_src, decoder_src)
495+
loadmodel!(m1, m2)
496+
@test m1[1][2].weight === parent(m1[2][1].weight)
497+
@test m1[1][1].weight == m2[1][1].weight
498+
@test m1[1][1].weight != permutedims(m1[2][2].weight)
499+
# mismatched weights are an error
500+
m2 = Chain(Chain(Dense(10 => 5), Dense(5 => 2)), Chain(Dense(2 => 5), Dense(5 => 10)))
501+
@test_throws ErrorException loadmodel!(m1, m2)
502+
# loading into tied weights with absent parameter is okay when the dst == zero
503+
b = Flux.zeros32(5)
504+
m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b))
505+
m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false))
506+
loadmodel!(m1, m2)
507+
@test m1[1].bias === m1[2].bias
508+
@test iszero(m1[1].bias)
509+
# loading into tied weights with absent parameter is bad when the dst != zero
510+
m2[1].bias .= 1
511+
@test_throws ErrorException loadmodel!(m1, m2)
497512
end
498513

499514
@testset "destructure" begin

0 commit comments

Comments
 (0)