|
504 | 504 | # loading into tied weights with absent parameter is bad when the dst != zero
|
505 | 505 | m2[1].bias .= 1
|
506 | 506 | @test_throws ErrorException loadmodel!(m1, m2)
|
| 507 | + |
| 508 | + @testset "loadmodel! & filter" begin |
| 509 | + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) |
| 510 | + m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) |
| 511 | + m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) |
| 512 | + |
| 513 | + # this will not error cause Dropout is skipped |
| 514 | + loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) |
| 515 | + @test m1[1].weight == m2[1].weight |
| 516 | + @test m1[2].weight == m2[3].weight |
| 517 | + |
| 518 | + # this will not error cause Dropout is skipped |
| 519 | + loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) |
| 520 | + @test m3[1].weight == m2[1].weight |
| 521 | + @test m3[2].weight == m2[3].weight |
| 522 | + end |
| 523 | + |
| 524 | + @testset "loadmodel! & absent bias" begin |
| 525 | + m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) |
| 526 | + m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) |
| 527 | + m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) |
| 528 | + |
| 529 | + Flux.loadmodel!(m1, m2) |
| 530 | + @test m1[1].bias == 7:9 |
| 531 | + @test sum(m1[1].weight) == 21 |
| 532 | + |
| 533 | + # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it |
| 534 | + m1 = Flux.loadmodel!(m1, m0) |
| 535 | + @test iszero(m1[1].bias) |
| 536 | + @test sum(m1[1].weight) == 6 # written before error |
| 537 | + |
| 538 | + # load into a model without bias -- should it ignore the parameter which has no home, or error? |
| 539 | + m0 = Flux.loadmodel!(m0, m2) |
| 540 | + @test iszero(m0[1].bias) # obviously unchanged |
| 541 | + @test sum(m0[1].weight) == 21 |
| 542 | + end |
507 | 543 | end
|
508 | 544 |
|
509 | 545 | @testset "destructure" begin
|
|
525 | 561 | end
|
526 | 562 | end
|
527 | 563 |
|
528 |
| -@testset "loadmodel! & absent bias" begin |
529 |
| - m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) |
530 |
| - m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) |
531 |
| - m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) |
532 |
| - |
533 |
| - Flux.loadmodel!(m1, m2) |
534 |
| - @test m1[1].bias == 7:9 |
535 |
| - @test sum(m1[1].weight) == 21 |
536 |
| - |
537 |
| - # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it |
538 |
| - m1 = Flux.loadmodel!(m1, m0) |
539 |
| - @test iszero(m1[1].bias) |
540 |
| - @test sum(m1[1].weight) == 6 # written before error |
541 |
| - |
542 |
| - # load into a model without bias -- should it ignore the parameter which has no home, or error? |
543 |
| - m0 = Flux.loadmodel!(m0, m2) |
544 |
| - @test iszero(m0[1].bias) # obviously unchanged |
545 |
| - @test sum(m0[1].weight) == 21 |
546 |
| -end |
547 |
| - |
548 | 564 | @testset "Train and test mode" begin
|
549 | 565 | mutable struct DummyLayer
|
550 | 566 | testing::Bool
|
|
0 commit comments