Skip to content

Commit ee4c31d

Browse files
Merge pull request #2742 from SciML/mooncake_direct
Mooncake Direct Adjoints for Explicit Integrators
2 parents 55e359e + 96a12f9 commit ee4c31d

File tree

7 files changed

+51
-6
lines changed

7 files changed

+51
-6
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
3939

4040
[weakdeps]
4141
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
42+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4243

4344
[extensions]
4445
OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore"
46+
OrdinaryDiffEqCoreMooncakeExt = "Mooncake"
4547

4648
[compat]
4749
ADTypes = "0.2, 1"
@@ -63,6 +65,7 @@ InteractiveUtils = "1.9"
6365
LinearAlgebra = "1.9"
6466
Logging = "1.9"
6567
MacroTools = "0.5"
68+
Mooncake = "0.4"
6669
MuladdMacro = "0.2.1"
6770
Polyester = "0.7"
6871
PrecompileTools = "1"

lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,9 @@ function EnzymeCore.EnzymeRules.inactive_noinl(
2626
true
2727
end
2828

29+
function EnzymeCore.EnzymeRules.inactive_noinl(
30+
::typeof(OrdinaryDiffEqCore.final_progress), args...)
31+
true
32+
end
33+
2934
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module OrdinaryDiffEqCoreMooncakeExt
2+
3+
using OrdinaryDiffEqCore, Mooncake
4+
using Mooncake: @zero_adjoint, MinimalCtx
5+
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(OrdinaryDiffEqCore.ode_determine_initdt), Any, Any, Any, Any, Any, Any, Any, Any, Any}
6+
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(OrdinaryDiffEqCore.SciMLBase.check_error), Any}
7+
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), Any, Any}
8+
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(OrdinaryDiffEqCore.final_progress), Any}
9+
10+
end

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,19 @@ function _postamble!(integrator)
195195
resize!(integrator.sol.k, integrator.saveiter_dense)
196196
end
197197
if integrator.opts.progress
198-
@logmsg(LogLevel(-1),
199-
integrator.opts.progress_name,
200-
_id=integrator.opts.progress_id,
201-
message=integrator.opts.progress_message(integrator.dt, integrator.u,
202-
integrator.p, integrator.t),
203-
progress="done")
198+
204199
end
205200
end
206201

202+
function final_progress(integrator)
203+
@logmsg(LogLevel(-1),
204+
integrator.opts.progress_name,
205+
_id=integrator.opts.progress_id,
206+
message=integrator.opts.progress_message(integrator.dt, integrator.u,
207+
integrator.p, integrator.t),
208+
progress="done")
209+
end
210+
207211
function solution_endpoint_match_cur_integrator!(integrator)
208212
if integrator.opts.save_end &&
209213
(integrator.saveiter == 0 ||

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ DDEProblemLibrary = "f42792ee-6ffc-4e2a-ae83-8ee2f22de800"
33
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
55
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
6+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
67
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
78
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
89
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

test/downstream/mooncake.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Mooncake, OrdinaryDiffEq, StaticArrays
2+
3+
function lorenz!(du, u, p, t)
4+
du[1] = 10.0(u[2] - u[1])
5+
du[2] = u[1] * (28.0 - u[3]) - u[2]
6+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
7+
end
8+
9+
const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
10+
11+
function f(u0::Array{Float64})
12+
tspan = (0.0, 3.0)
13+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
14+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
15+
sum(sol)
16+
end;
17+
u0 = [1.0; 0.0; 0.0]
18+
mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]
19+
mooncake_gradient(f, u0)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ end
156156
activate_downstream_env()
157157
@time @safetestset "DelayDiffEq Tests" include("downstream/delaydiffeq.jl")
158158
@time @safetestset "Measurements Tests" include("downstream/measurements.jl")
159+
if VERSION >= v"1.11"
160+
@time @safetestset "Mooncake Tests" include("downstream/mooncake.jl")
161+
end
159162
@time @safetestset "Sparse Diff Tests" include("downstream/sparsediff_tests.jl")
160163
@time @safetestset "Time derivative Tests" include("downstream/time_derivative_test.jl")
161164
end

0 commit comments

Comments
 (0)