|
17 | 17 | _parent(x) = x
|
18 | 18 | _parent(x::AbstractArray) = parent(x)
|
19 | 19 |
|
20 |
| -_tie_check(dst::AbstractArray, src::AbstractArray) = dst == src |
| 20 | +_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) || |
| 21 | + error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") |
| 22 | +_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) || |
| 23 | + error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") |
| 24 | +_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) || |
| 25 | + error("Encountered tied destination parameters with untied and mismatched sources.") |
21 | 26 | _tie_check(dst, src) = true
|
22 | 27 |
|
23 |
| -_bool_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) |
24 |
| -_bool_tie_check(dst::AbstractArray, src::Bool) = iszero(dst) && iszero(src) |
25 | 28 | _bool_tie_check(dst, src) = true
|
26 | 29 |
|
27 | 30 | """
|
@@ -77,13 +80,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())
|
77 | 80 | err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
|
78 | 81 | foreach(ldsts, lsrcs) do ldst, lsrc
|
79 | 82 | if _parent(ldst) in cache # we already loaded this parameter before
|
80 |
| - if !_bool_tie_check(ldst, lsrc) # special case to handle tied + absent parameters |
81 |
| - error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") |
82 |
| - elseif _tie_check(ldst, lsrc) # the arrays match and we already loaded (or these are not arrays) |
83 |
| - return ldst |
84 |
| - else # tied dst but mismatched src case |
85 |
| - error("Encountered tied destination parameters with untied and mismatched sources.") |
86 |
| - end |
| 83 | + _tie_check(ldst, lsrc) && return ldst |
87 | 84 | elseif Functors.isleaf(ldst) # our first time loading this leaf
|
88 | 85 | push!(cache, ldst)
|
89 | 86 | loadleaf!(ldst, lsrc, err)
|
|
0 commit comments