@@ -1,5 +1,7 @@
 using Flux.Optimise
-using Flux.Optimise: runall
+using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend
+import AbstractDifferentiation as AD  # used by the AD-backend tests below
+import Tracker
 using Flux: Params, gradient
 import FillArrays, ComponentArrays
 using Test
@@ -45,6 +47,42 @@
   end
 end

+@testset "AD backends" begin
+  # this is a hack to make Tracker work: forward AbstractDifferentiation's
+  # gradient API to Tracker.withgradient
+  AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
+  AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...)
+
+  # the implicit Zygote backend differentiates a zero-argument closure
+  # with respect to Flux.params(model)
+  function _loss_and_model(::ZygoteImplicitBackend, loss, model)
+    return () -> loss(model), Flux.params(model)
+  end
+  # explicit backends differentiate the loss with respect to the model itself
+  _loss_and_model(ad, loss, model) = loss, model
+
+  # implicit gradients are keyed by the parameter arrays
+  function _check_gradient(::ZygoteImplicitBackend, model, grad)
+    return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) &&
+           grad[model[2].weight] == 10 .* Flux.ones32(2, 5)
+  end
+  # explicit gradients mirror the model structure
+  function _check_gradient(ad, model, grad)
+    return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) &&
+           grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5)
+  end
+
+  @testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()]
+    model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
+    x = Flux.ones32(10)
+    _loss, _model = _loss_and_model(ad, m -> sum(m(x)), model)
+    val, grad = AD.value_and_gradient(ad, _loss, _model)
+    @test val == sum(model(x))
+    @test _check_gradient(ad, model, grad)
+    @test _check_gradient(ad, model, AD.gradient(ad, _loss, _model))
+  end
+end
+
 @testset "Training Loop" begin
   i = 0
   l = 1