|
508 | 508 | # loading into tied weights with absent parameter is bad when the dst != zero
|
509 | 509 | m2[1].bias .= 1
|
510 | 510 | @test_throws ErrorException loadmodel!(m1, m2)
|
| 511 | + |
| 512 | + @testset "loadmodel! & filter" begin |
| 513 | + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) |
| 514 | + m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) |
| 515 | + m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) |
| 516 | + |
| 517 | + # this will not error cause Dropout is skipped |
| 518 | + loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) |
| 519 | + @test m1[1].weight == m2[1].weight |
| 520 | + @test m1[2].weight == m2[3].weight |
| 521 | + |
| 522 | + # this will not error cause Dropout is skipped |
| 523 | + loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) |
| 524 | + @test m3[1].weight == m2[1].weight |
| 525 | + @test m3[2].weight == m2[3].weight |
| 526 | + end |
| 527 | + |
| 528 | + @testset "loadmodel! & absent bias" begin |
| 529 | + m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) |
| 530 | + m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) |
| 531 | + m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) |
| 532 | + |
| 533 | + Flux.loadmodel!(m1, m2) |
| 534 | + @test m1[1].bias == 7:9 |
| 535 | + @test sum(m1[1].weight) == 21 |
| 536 | + |
| 537 | + # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it |
| 538 | + m1 = Flux.loadmodel!(m1, m0) |
| 539 | + @test iszero(m1[1].bias) |
| 540 | + @test sum(m1[1].weight) == 6 # written before error |
| 541 | + |
| 542 | + # load into a model without bias -- should it ignore the parameter which has no home, or error? |
| 543 | + m0 = Flux.loadmodel!(m0, m2) |
| 544 | + @test iszero(m0[1].bias) # obviously unchanged |
| 545 | + @test sum(m0[1].weight) == 21 |
| 546 | + end |
511 | 547 | end
|
512 | 548 |
|
513 | 549 | @testset "destructure" begin
|
|
529 | 565 | end
|
530 | 566 | end
|
531 | 567 |
|
532 |
| -@testset "loadmodel! & absent bias" begin |
533 |
| - m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) |
534 |
| - m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) |
535 |
| - m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) |
536 |
| - |
537 |
| - Flux.loadmodel!(m1, m2) |
538 |
| - @test m1[1].bias == 7:9 |
539 |
| - @test sum(m1[1].weight) == 21 |
540 |
| - |
541 |
| - # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it |
542 |
| - m1 = Flux.loadmodel!(m1, m0) |
543 |
| - @test iszero(m1[1].bias) |
544 |
| - @test sum(m1[1].weight) == 6 # written before error |
545 |
| - |
546 |
| - # load into a model without bias -- should it ignore the parameter which has no home, or error? |
547 |
| - m0 = Flux.loadmodel!(m0, m2) |
548 |
| - @test iszero(m0[1].bias) # obviously unchanged |
549 |
| - @test sum(m0[1].weight) == 21 |
550 |
| -end |
551 |
| - |
552 | 568 | @testset "Train and test mode" begin
|
553 | 569 | mutable struct DummyLayer
|
554 | 570 | testing::Bool
|
|
0 commit comments