@@ -2,7 +2,7 @@ using Flux
2
2
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
3
3
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
4
4
sparse_init, identity_init, stack, unstack, batch, unbatch,
5
- unsqueeze, params, loadmodel!
5
+ unsqueeze, params, loadparams!, loadmodel!
6
6
using StatsBase: var, std
7
7
using Statistics, LinearAlgebra
8
8
using Random
@@ -366,26 +366,24 @@ end
366
366
@test_skip typeof (l1. bias) === typeof (l2. bias)
367
367
end
368
368
369
# Test set for loading a model's own parameters back into itself.
# NOTE(review): this span is shown in diff form. The `-` lines carry the old
# `loadmodel!` spelling and the `+` lines the new `loadparams!` spelling; the bare
# integer lines look like line-number residue from the diff rendering — confirm
# before treating this text as runnable Julia.
- @testset " loadmodel !" begin
369
+ @testset " loadparams !" begin
370
370
# Pack a weight/bias pair into a two-element vector.
pars (w, b) = [w, b]
371
371
# Same, reading the `weight` and `bias` fields of `l`.
pars (l) = pars (l. weight, l. bias)
372
372
# Concatenate `pars` of every element of `m` with `vcat`.
# Not referenced again inside this test set as shown.
pararray (m) = mapreduce (pars, vcat, m)
373
373
# Concatenate only the `weight` field of every element of `m`.
# Not referenced again inside this test set as shown.
weights (m) = mapreduce (l -> [l. weight], vcat, m)
374
374
# One nested test set per bias constructor; `nobias`, `dm` and `testdense`
# are defined outside this excerpt.
@testset " Bias type $bt " for bt in (Flux. zeros32, nobias)
375
375
m = dm (bt)
376
# Round-trip: feed the model's own `params` back into the same model.
- Flux. loadmodel ! (m, params (m))
376
+ Flux. loadparams ! (m, params (m))
377
377
# presumably verifies the model is unchanged after the reload — TODO confirm
# against the definition of `testdense`.
testdense (m, bt)
378
378
end
379
379
end
380
380
381
381
@testset " loadmodel!(dst, src)" begin
382
- import Flux: loadmodel!, Zeros
383
-
384
382
m1 = Chain (Dense (10 , 5 ), Dense (5 , 2 , relu))
385
383
m2 = Chain (Dense (10 , 5 ), Dense (5 , 2 ))
386
384
m3 = Chain (Conv ((3 , 3 ), 3 => 16 ), Dense (5 , 2 ))
387
385
m4 = Chain (Dense (10 , 6 ), Dense (6 , 2 ))
388
- m5 = Chain (Dense (10 , 5 ), Parallel (+ , Dense (Flux. ones32 (2 , 5 ), Zeros () ), Dense (5 , 2 )))
386
+ m5 = Chain (Dense (10 , 5 ), Parallel (+ , Dense (Flux. ones32 (2 , 5 ), false ), Dense (5 , 2 )))
389
387
m6 = Chain (Dense (10 , 5 ), Parallel (+ , Dense (5 , 2 ), Dense (5 , 2 )))
390
388
391
389
loadmodel! (m1, m2)
408
406
# size mismatches throw an error
409
407
@test_throws DimensionMismatch loadmodel! (m1, m4)
410
408
409
+ # tests for BatchNorm and Dropout
411
410
m1 = Chain (Conv ((3 , 3 ), 3 => 16 ), BatchNorm (16 ), Flux. flatten, Dropout (0.2 ))
412
411
m2 = Chain (Conv ((3 , 3 ), 3 => 16 ), BatchNorm (16 ), x -> reshape (x, :, size (x)[end ]), Dropout (0.1 ))
413
412
m2[2 ]. μ .= rand (Float32, size (m2[2 ]. μ)... )
@@ -484,16 +483,32 @@ end
484
483
@test chain2[2 ]. bias != chain1[2 ]. bias
485
484
486
485
# test shared weights
487
- m1 = Chain (Dense (10 => 5 ), Dense (5 => 2 ))
488
- m2 = Chain (Dense (transpose (m1[2 ]. weight)), Dense (permutedims (m1[1 ]. weight)))
489
- m3 = Chain (Dense (m1[1 ]. weight), Dense (m1[2 ]. weight))
490
- m2[2 ]. weight .= 1f0
491
- loadmodel! (m1, m3)
492
- @test m1[2 ]. weight === parent (m2[1 ]. weight)
493
- @test m1[2 ]. weight == transpose (m2[1 ]. weight)
494
- @test m1[1 ]. weight === m3[1 ]. weight
495
- @test m2[2 ]. weight != transpose (m1[1 ]. weight)
496
- @test m3[2 ]. weight == transpose (m2[1 ]. weight)
486
+ encoder_dst = Chain (Dense (10 => 5 ), Dense (5 => 2 ))
487
+ decoder_dst = Chain (Dense (transpose (encoder_dst[2 ]. weight)),
488
+ Dense (permutedims (encoder_dst[1 ]. weight)))
489
+ encoder_src = Chain (Dense (10 => 5 ), Dense (5 => 2 ))
490
+ decoder_src = Chain (Dense (transpose (encoder_src[2 ]. weight)),
491
+ Dense (5 => 10 ))
492
+ # matched weights are okay
493
+ m1 = Chain (encoder_dst, decoder_dst)
494
+ m2 = Chain (encoder_src, decoder_src)
495
+ loadmodel! (m1, m2)
496
+ @test m1[1 ][2 ]. weight === parent (m1[2 ][1 ]. weight)
497
+ @test m1[1 ][1 ]. weight == m2[1 ][1 ]. weight
498
+ @test m1[1 ][1 ]. weight != permutedims (m1[2 ][2 ]. weight)
499
+ # mismatched weights are an error
500
+ m2 = Chain (Chain (Dense (10 => 5 ), Dense (5 => 2 )), Chain (Dense (2 => 5 ), Dense (5 => 10 )))
501
+ @test_throws ErrorException loadmodel! (m1, m2)
502
+ # loading into tied weights with absent parameter is okay when the dst == zero
503
+ b = Flux. zeros32 (5 )
504
+ m1 = Chain (Dense (10 => 5 ; bias = b), Dense (5 => 5 ; bias = b))
505
+ m2 = Chain (Dense (10 => 5 ; bias = Flux. zeros32 (5 )), Dense (5 => 5 ; bias = false ))
506
+ loadmodel! (m1, m2)
507
+ @test m1[1 ]. bias === m1[2 ]. bias
508
+ @test iszero (m1[1 ]. bias)
509
+ # loading into tied weights with absent parameter is bad when the dst != zero
510
+ m2[1 ]. bias .= 1
511
+ @test_throws ErrorException loadmodel! (m1, m2)
497
512
end
498
513
499
514
@testset " destructure" begin
0 commit comments