|
1 | 1 | using Flux.Train
|
2 | 2 | using Zygote: Params, gradient
|
3 | 3 |
|
4 |
| -import FillArrays, ComponentArrays |
| 4 | +import Optimisers, FillArrays, ComponentArrays |
5 | 5 |
|
6 | 6 | using Test
|
7 | 7 | using Random
|
|
29 | 29 | w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
|
30 | 30 | @testset for opt in [Descent(0.1), Adam()]
|
31 | 31 | @test opt isa FluxState
|
32 |
| - w′ = copy(w2) |
33 |
| - b = zeros(10) |
| 32 | + @test opt.state isa Missing |
| 33 | + |
| 34 | + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) |
| 35 | + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) |
| 36 | + @test loss(model, rand(10, 10)) > 1 |
| 37 | + |
| 38 | + train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) |
| 39 | + @test loss(model, rand(10, 10)) < 0.01 |
| 40 | + @test opt.state isa NamedTuple |
| 41 | + end |
| 42 | + |
| 43 | + # Test 3-arg `train!` method: |
| 44 | + @testset for opt in [Descent(0.1), Adam()] |
| 45 | + @test opt isa FluxState |
| 46 | + @test opt.state isa Missing |
| 47 | + |
| 48 | + loss(m) = let x = rand(10) |
| 49 | + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) |
| 50 | + end |
| 51 | + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) |
| 52 | + @test loss(model) > 1 |
| 53 | + |
| 54 | + for i in 1:10^5 |
| 55 | + train!(loss, model, opt) |
| 56 | + end |
| 57 | + @test loss(model) < 0.01 |
| 58 | + @test opt.state isa NamedTuple |
| 59 | + end |
| 60 | + |
| 61 | + # Test direct use of Optimisers.jl rule, only really OK for `Descent`: |
| 62 | + @testset for opt in [Optimisers.Descent(0.1), Optimisers.Adam()] |
| 63 | + @test opt isa Optimisers.AbstractRule |
34 | 64 | loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
|
35 |
| - model = (weight=w′, bias=b, ignore=nothing) |
| 65 | + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) |
36 | 66 | @test loss(model, rand(10, 10)) > 1
|
37 | 67 | train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
|
38 | 68 | @test loss(model, rand(10, 10)) < 0.01
|
|
0 commit comments