Skip to content

Commit 4aae5cc

Browse files
committed
Add ability to filter loadmodel! recursion
1 parent f9b95c4 commit 4aae5cc

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

src/loading.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ _tie_check(dst, src) = true
2828

2929
_bool_tie_check(dst, src) = true
3030

31+
_filter_children(f, children::NamedTuple) =
32+
NamedTuple(filter(kv -> f(kv[2]), pairs(children)))
33+
_filter_children(f, children) = filter(f, children)
34+
3135
"""
3236
loadmodel!(dst, src)
3337
@@ -77,9 +81,9 @@ however, attempting to copy a non-zero array to an inactive parameter will throw
7781
Likewise, copying a `src` value of `false` to any `dst` array is valid,
7882
but copying a `src` value of `true` will error.
7983
"""
80-
function loadmodel!(dst, src; cache = Base.IdSet())
81-
ldsts, _ = functor(dst)
82-
lsrcs, _ = functor(src)
84+
function loadmodel!(dst, src; filter = Returns(true), cache = Base.IdSet())
85+
ldsts = _filter_children(filter, functor(dst)[1])
86+
lsrcs = _filter_children(filter, functor(src)[1])
8387
(keys(ldsts) == keys(lsrcs)) ||
8488
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
8589

@@ -91,7 +95,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())
9195
push!(cache, ldst)
9296
loadleaf!(ldst, lsrc, err)
9397
else # this isn't a leaf
94-
loadmodel!(ldst, lsrc; cache = cache)
98+
loadmodel!(ldst, lsrc; filter = filter, cache = cache)
9599
end
96100
end
97101

test/utils.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,42 @@ end
504504
# loading into tied weights with absent parameter is bad when the dst != zero
505505
m2[1].bias .= 1
506506
@test_throws ErrorException loadmodel!(m1, m2)
507+
508+
@testset "loadmodel! & filter" begin
509+
m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
510+
m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2))
511+
m3 = Chain(Dense(10, 5), Dense(5, 2, relu))
512+
513+
# this will not error cause Dropout is skipped
514+
loadmodel!(m1, m2; filter = x -> !(x isa Dropout))
515+
@test m1[1].weight == m2[1].weight
516+
@test m1[2].weight == m2[3].weight
517+
518+
# this will not error cause Dropout is skipped
519+
loadmodel!(m2, m3; filter = x -> !(x isa Dropout))
520+
@test m3[1].weight == m2[1].weight
521+
@test m3[2].weight == m2[3].weight
522+
end
523+
524+
@testset "loadmodel! & absent bias" begin
525+
m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1))
526+
m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
527+
m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
528+
529+
Flux.loadmodel!(m1, m2)
530+
@test m1[1].bias == 7:9
531+
@test sum(m1[1].weight) == 21
532+
533+
# load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
534+
m1 = Flux.loadmodel!(m1, m0)
535+
@test iszero(m1[1].bias)
536+
@test sum(m1[1].weight) == 6 # written before error
537+
538+
# load into a model without bias -- should it ignore the parameter which has no home, or error?
539+
m0 = Flux.loadmodel!(m0, m2)
540+
@test iszero(m0[1].bias) # obviously unchanged
541+
@test sum(m0[1].weight) == 21
542+
end
507543
end
508544

509545
@testset "destructure" begin
@@ -525,26 +561,6 @@ end
525561
end
526562
end
527563

528-
@testset "loadmodel! & absent bias" begin
529-
m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1))
530-
m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
531-
m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
532-
533-
Flux.loadmodel!(m1, m2)
534-
@test m1[1].bias == 7:9
535-
@test sum(m1[1].weight) == 21
536-
537-
# load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
538-
m1 = Flux.loadmodel!(m1, m0)
539-
@test iszero(m1[1].bias)
540-
@test sum(m1[1].weight) == 6 # written before error
541-
542-
# load into a model without bias -- should it ignore the parameter which has no home, or error?
543-
m0 = Flux.loadmodel!(m0, m2)
544-
@test iszero(m0[1].bias) # obviously unchanged
545-
@test sum(m0[1].weight) == 21
546-
end
547-
548564
@testset "Train and test mode" begin
549565
mutable struct DummyLayer
550566
testing::Bool

0 commit comments

Comments
 (0)