@@ -483,21 +483,16 @@ end
483
483
@test chain2[2 ]. bias != chain1[2 ]. bias
484
484
485
485
# test shared weights
486
- encoder_dst = Chain (Dense (10 => 5 ), Dense (5 => 2 ))
487
- decoder_dst = Chain (Dense (transpose (encoder_dst[2 ]. weight)),
488
- Dense (permutedims (encoder_dst[1 ]. weight)))
489
- encoder_src = Chain (Dense (10 => 5 ), Dense (5 => 2 ))
490
- decoder_src = Chain (Dense (transpose (encoder_src[2 ]. weight)),
491
- Dense (5 => 10 ))
486
+ shared_dst = Dense (10 => 10 )
487
+ shared_src = Dense (10 => 10 )
492
488
# matched weights are okay
493
- m1 = Chain (encoder_dst, decoder_dst )
494
- m2 = Chain (encoder_src, decoder_src )
489
+ m1 = Chain (shared_dst, Dense (shared_dst . weight) )
490
+ m2 = Chain (shared_src, Dense (shared_src . weight) )
495
491
loadmodel! (m1, m2)
496
- @test m1[1 ][2 ]. weight === parent (m1[2 ][1 ]. weight)
497
- @test m1[1 ][1 ]. weight == m2[1 ][1 ]. weight
498
- @test m1[1 ][1 ]. weight != permutedims (m1[2 ][2 ]. weight)
492
+ @test m1[1 ]. weight === m1[2 ]. weight
493
+ @test m1[1 ]. weight == m2[2 ]. weight
499
494
# mismatched weights are an error
500
- m2 = Chain (Chain ( Dense (10 => 5 ), Dense (5 => 2 )), Chain ( Dense ( 2 => 5 ), Dense ( 5 => 10 ) ))
495
+ m2 = Chain (Dense (10 => 10 ), Dense (10 => 10 ))
501
496
@test_throws ErrorException loadmodel! (m1, m2)
502
497
# loading into tied weights with absent parameter is okay when the dst == zero
503
498
b = Flux. zeros32 (5 )
0 commit comments