Skip to content

Commit 0d55c00

Browse files
committed
Remove _parent
1 parent c831955 commit 0d55c00

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

src/loading.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
1414
copyto!(dst, src)
1515
end
1616

17-
_parent(x) = x
18-
_parent(x::AbstractArray) = parent(x)
19-
2017
_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
2118
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
2219
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
@@ -79,7 +76,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())
7976

8077
err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
8178
foreach(ldsts, lsrcs) do ldst, lsrc
82-
if _parent(ldst) in cache # we already loaded this parameter before
79+
if ldst in cache # we already loaded this parameter before
8380
_tie_check(ldst, lsrc) && return ldst
8481
elseif Functors.isleaf(ldst) # our first time loading this leaf
8582
push!(cache, ldst)

test/utils.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -483,21 +483,16 @@ end
483483
@test chain2[2].bias != chain1[2].bias
484484

485485
# test shared weights
486-
encoder_dst = Chain(Dense(10 => 5), Dense(5 => 2))
487-
decoder_dst = Chain(Dense(transpose(encoder_dst[2].weight)),
488-
Dense(permutedims(encoder_dst[1].weight)))
489-
encoder_src = Chain(Dense(10 => 5), Dense(5 => 2))
490-
decoder_src = Chain(Dense(transpose(encoder_src[2].weight)),
491-
Dense(5 => 10))
486+
shared_dst = Dense(10 => 10)
487+
shared_src = Dense(10 => 10)
492488
# matched weights are okay
493-
m1 = Chain(encoder_dst, decoder_dst)
494-
m2 = Chain(encoder_src, decoder_src)
489+
m1 = Chain(shared_dst, Dense(shared_dst.weight))
490+
m2 = Chain(shared_src, Dense(shared_src.weight))
495491
loadmodel!(m1, m2)
496-
@test m1[1][2].weight === parent(m1[2][1].weight)
497-
@test m1[1][1].weight == m2[1][1].weight
498-
@test m1[1][1].weight != permutedims(m1[2][2].weight)
492+
@test m1[1].weight === m1[2].weight
493+
@test m1[1].weight == m2[2].weight
499494
# mismatched weights are an error
500-
m2 = Chain(Chain(Dense(10 => 5), Dense(5 => 2)), Chain(Dense(2 => 5), Dense(5 => 10)))
495+
m2 = Chain(Dense(10 => 10), Dense(10 => 10))
501496
@test_throws ErrorException loadmodel!(m1, m2)
502497
# loading into tied weights with absent parameter is okay when the dst == zero
503498
b = Flux.zeros32(5)

0 commit comments

Comments
 (0)