Skip to content

Commit 80c596a

Browse files
add a test
1 parent 9088eb7 commit 80c596a

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ DDEProblemLibrary = "f42792ee-6ffc-4e2a-ae83-8ee2f22de800"
33
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
55
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
6+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
67
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
78
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
89
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

test/downstream/mooncake.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Mooncake, OrdinaryDiffEq, StaticArrays
2+
3+
function lorenz!(du, u, p, t)
4+
du[1] = 10.0(u[2] - u[1])
5+
du[2] = u[1] * (28.0 - u[3]) - u[2]
6+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
7+
end
8+
9+
const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
10+
11+
function f(u0::Array{Float64})
12+
tspan = (0.0, 3.0)
13+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
14+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
15+
sum(sol)
16+
end;
17+
u0 = [1.0; 0.0; 0.0]
18+
mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]
19+
mooncake_gradient(f, u0)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ end
156156
activate_downstream_env()
157157
@time @safetestset "DelayDiffEq Tests" include("downstream/delaydiffeq.jl")
158158
@time @safetestset "Measurements Tests" include("downstream/measurements.jl")
159+
@time @safetestset "Sparse Diff Tests" include("downstream/mooncake.jl")
159160
@time @safetestset "Sparse Diff Tests" include("downstream/sparsediff_tests.jl")
160161
@time @safetestset "Time derivative Tests" include("downstream/time_derivative_test.jl")
161162
end

0 commit comments

Comments
 (0)