
Commit 37c9759

Add tests for AD backends
1 parent fbde477 commit 37c9759

File tree

3 files changed: +35 lines, -2 lines

Project.toml

Lines changed: 2 additions & 1 deletion

@@ -53,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"]

test/optimise.jl

Lines changed: 31 additions & 1 deletion

@@ -1,5 +1,5 @@
 using Flux.Optimise
-using Flux.Optimise: runall
+using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend
 using Flux: Params, gradient
 import FillArrays, ComponentArrays
 using Test
@@ -45,6 +45,36 @@ end
 end
 end
 
+@testset "AD backends" begin
+  # this is hack to make Tracker work
+  AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
+  AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...)
+
+  function _loss_and_model(::ZygoteImplicitBackend, loss, model)
+    return () -> loss(model), Flux.params(model)
+  end
+  _loss_and_model(ad, loss, model) = loss, model
+
+  function _check_gradient(::ZygoteImplicitBackend, model, grad)
+    return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) &&
+           grad[model[2].weight] == 10 .* Flux.ones32(2, 5)
+  end
+  function _check_gradient(ad, model, grad)
+    return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) &&
+           grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5)
+  end
+
+  @testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()]
+    model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
+    x = Flux.ones32(10)
+    _loss, _model = _loss_and_model(ad, m -> sum(m(x)), model)
+    val, grad = AD.value_and_gradient(ad, _loss, _model)
+    @test val == sum(model(x))
+    @test _check_gradient(ad, model, grad)
+    @test _check_gradient(ad, model, AD.gradient(ad, _loss, _model))
+  end
+end
+
 @testset "Training Loop" begin
   i = 0
   l = 1
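For reference, these tests drive the generic AbstractDifferentiation.jl interface (the exported `AD` alias): `AD.value_and_gradient(backend, f, xs...)` returns the primal value plus a tuple of gradients, one per argument, while `AD.gradient` returns only the gradients. A minimal standalone sketch with the plain `AD.ZygoteBackend()`; note that the `ZygoteImplicitBackend`/`ZygoteExplicitBackend` types tested above are assumed to be Flux-specific wrappers added elsewhere in this PR, not part of AbstractDifferentiation itself:

    # Sketch only: assumes AbstractDifferentiation.jl and Zygote.jl are installed.
    import AbstractDifferentiation as AD
    using Zygote   # makes AD.ZygoteBackend() available

    f(x) = sum(abs2, x)          # simple scalar-valued loss
    x = [1.0, 2.0, 3.0]
    backend = AD.ZygoteBackend()

    # primal value and a tuple of gradients, one entry per argument of f
    val, grads = AD.value_and_gradient(backend, f, x)
    @assert val == 14.0
    @assert grads[1] == 2 .* x

    # gradients only
    g = AD.gradient(backend, f, x)
    @assert g[1] == 2 .* x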

test/runtests.jl

Lines changed: 2 additions & 0 deletions

@@ -5,7 +5,9 @@ using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
+import Tracker
 using Zygote
+using AbstractDifferentiation
 using CUDA
 
 Random.seed!(0)
