@@ -33,7 +33,7 @@ _filter_children(f, children::NamedTuple) =
33
33
_filter_children (f, children) = filter (f, children)
34
34
35
35
"""
36
- loadmodel!(dst, src)
36
+ loadmodel!(dst, src; filter = _ -> true )
37
37
38
38
Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
39
39
@@ -43,9 +43,12 @@ Non-array elements (such as activation functions) are not copied and need not ma
43
43
Zero bias vectors and `bias=false` are considered equivalent
44
44
(see extended help for more details).
45
45
46
+ Specify the predicate function `filter` to control what is recursed.
47
+ A child node `x` in either `dst` and `src` is skipped when `filter(x) == false`.
48
+
46
49
# Examples
47
50
```julia
48
- julia> dst = Chain(Dense(Flux.ones32(2, 5, tanh) ), Dense(2 => 1; bias = [1f0]))
51
+ julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh ), Dense(2 => 1; bias = [1f0]))
49
52
Chain(
50
53
Dense(5 => 2, tanh), # 12 parameters
51
54
Dense(2 => 1), # 3 parameters
63
66
64
67
julia> iszero(dst[2].bias)
65
68
true
69
+
70
+ julia> src = Chain(Dense(5 => 2), Dropout(0.2), Dense(2 => 1))
71
+ Chain(
72
+ Dense(5 => 2), # 12 parameters
73
+ Dropout(0.2),
74
+ Dense(2 => 1), # 3 parameters
75
+ ) # Total: 4 arrays, 15 parameters, 348 bytes.
76
+
77
+ julia> Flux.loadmodel!(dst, src; filter = x -> !(x isa Dropout)) # skips loading Dropout
78
+ Chain(
79
+ Dense(5 => 2, tanh), # 12 parameters
80
+ Dense(2 => 1), # 3 parameters
81
+ ) # Total: 4 arrays, 15 parameters, 316 bytes.
66
82
```
67
83
68
84
# Extended help
0 commit comments