Skip to content

Commit 5546e1d

Browse files
authored
Merge pull request #2041 from darsnack/load-filter
2 parents dedc7ce + 39ba38f commit 5546e1d

File tree

2 files changed

+45
-25
lines changed

2 files changed

+45
-25
lines changed

src/loading.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ _tie_check(dst, src) = true
2828

2929
_bool_tie_check(dst, src) = true
3030

31+
_filter_children(f, children::NamedTuple) =
32+
NamedTuple(filter(kv -> f(kv[2]), pairs(children)))
33+
_filter_children(f, children) = filter(f, children)
34+
3135
"""
3236
loadmodel!(dst, src)
3337
@@ -41,7 +45,7 @@ Zero bias vectors and `bias=false` are considered equivalent
4145
4246
# Examples
4347
```julia
44-
julia> dst = Chain(Dense(Flux.ones32(2, 5, tanh)), Dense(2 => 1; bias = [1f0]))
48+
julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
4549
Chain(
4650
Dense(5 => 2, tanh), # 12 parameters
4751
Dense(2 => 1), # 3 parameters
@@ -77,9 +81,9 @@ however, attempting to copy a non-zero array to an inactive parameter will throw
7781
Likewise, copying a `src` value of `false` to any `dst` array is valid,
7882
but copying a `src` value of `true` will error.
7983
"""
80-
function loadmodel!(dst, src; cache = Base.IdSet())
81-
ldsts, _ = functor(dst)
82-
lsrcs, _ = functor(src)
84+
function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
85+
ldsts = _filter_children(filter, functor(dst)[1])
86+
lsrcs = _filter_children(filter, functor(src)[1])
8387
(keys(ldsts) == keys(lsrcs)) ||
8488
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
8589

@@ -91,7 +95,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())
9195
push!(cache, ldst)
9296
loadleaf!(ldst, lsrc, err)
9397
else # this isn't a leaf
94-
loadmodel!(ldst, lsrc; cache = cache)
98+
loadmodel!(ldst, lsrc; filter = filter, cache = cache)
9599
end
96100
end
97101

test/utils.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,42 @@ end
508508
# loading into tied weights with absent parameter is bad when the dst != zero
509509
m2[1].bias .= 1
510510
@test_throws ErrorException loadmodel!(m1, m2)
511+
512+
@testset "loadmodel! & filter" begin
513+
m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
514+
m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2))
515+
m3 = Chain(Dense(10, 5), Dense(5, 2, relu))
516+
517+
# this will not error cause Dropout is skipped
518+
loadmodel!(m1, m2; filter = x -> !(x isa Dropout))
519+
@test m1[1].weight == m2[1].weight
520+
@test m1[2].weight == m2[3].weight
521+
522+
# this will not error cause Dropout is skipped
523+
loadmodel!(m2, m3; filter = x -> !(x isa Dropout))
524+
@test m3[1].weight == m2[1].weight
525+
@test m3[2].weight == m2[3].weight
526+
end
527+
528+
@testset "loadmodel! & absent bias" begin
529+
m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1))
530+
m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
531+
m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
532+
533+
Flux.loadmodel!(m1, m2)
534+
@test m1[1].bias == 7:9
535+
@test sum(m1[1].weight) == 21
536+
537+
# load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
538+
m1 = Flux.loadmodel!(m1, m0)
539+
@test iszero(m1[1].bias)
540+
@test sum(m1[1].weight) == 6 # written before error
541+
542+
# load into a model without bias -- should it ignore the parameter which has no home, or error?
543+
m0 = Flux.loadmodel!(m0, m2)
544+
@test iszero(m0[1].bias) # obviously unchanged
545+
@test sum(m0[1].weight) == 21
546+
end
511547
end
512548

513549
@testset "destructure" begin
@@ -529,26 +565,6 @@ end
529565
end
530566
end
531567

532-
@testset "loadmodel! & absent bias" begin
533-
m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1))
534-
m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
535-
m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
536-
537-
Flux.loadmodel!(m1, m2)
538-
@test m1[1].bias == 7:9
539-
@test sum(m1[1].weight) == 21
540-
541-
# load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
542-
m1 = Flux.loadmodel!(m1, m0)
543-
@test iszero(m1[1].bias)
544-
@test sum(m1[1].weight) == 6 # written before error
545-
546-
# load into a model without bias -- should it ignore the parameter which has no home, or error?
547-
m0 = Flux.loadmodel!(m0, m2)
548-
@test iszero(m0[1].bias) # obviously unchanged
549-
@test sum(m0[1].weight) == 21
550-
end
551-
552568
@testset "Train and test mode" begin
553569
mutable struct DummyLayer
554570
testing::Bool

0 commit comments

Comments
 (0)