Skip to content

Commit c831955

Browse files
committed
Combine _bool_tie_check and _tie_check.
1 parent 29662b2 commit c831955

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

src/loading.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@ end
1717
_parent(x) = x
1818
_parent(x::AbstractArray) = parent(x)
1919

20-
_tie_check(dst::AbstractArray, src::AbstractArray) = dst == src
20+
_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
21+
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
22+
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
23+
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
24+
_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
25+
error("Encountered tied destination parameters with untied and mismatched sources.")
2126
_tie_check(dst, src) = true
2227

23-
_bool_tie_check(dst::Bool, src::AbstractArray) = iszero(dst)
24-
_bool_tie_check(dst::AbstractArray, src::Bool) = iszero(dst) && iszero(src)
2528
_bool_tie_check(dst, src) = true
2629

2730
"""
@@ -77,13 +80,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())
7780
err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
7881
foreach(ldsts, lsrcs) do ldst, lsrc
7982
if _parent(ldst) in cache # we already loaded this parameter before
80-
if !_bool_tie_check(ldst, lsrc) # special case to handle tied + absent parameters
81-
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
82-
elseif _tie_check(ldst, lsrc) # the arrays match and we already loaded (or these are not arrays)
83-
return ldst
84-
else # tied dst but mismatched src case
85-
error("Encountered tied destination parameters with untied and mismatched sources.")
86-
end
83+
_tie_check(ldst, lsrc) && return ldst
8784
elseif Functors.isleaf(ldst) # our first time loading this leaf
8885
push!(cache, ldst)
8986
loadleaf!(ldst, lsrc, err)

0 commit comments

Comments
 (0)