Skip to content

Commit e85bdb3

Browse files
Merge branch 'master' into optensemble
2 parents 1d87d19 + ffe68ae commit e85bdb3

20 files changed

+356
-150
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
julia-version: [1,1.6]
17+
julia-version: [1]
1818
os: [ubuntu-latest]
1919
package:
2020
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}

Project.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "2.5.0"
4+
version = "2.7.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
109
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1110
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1211
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -32,16 +31,17 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3231
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3332
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3433
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
35-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3634

3735
[weakdeps]
36+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3837
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
3938
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
4039
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
4140
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
4241
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4342

4443
[extensions]
44+
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
4545
SciMLBasePartialFunctionsExt = "PartialFunctions"
4646
SciMLBasePyCallExt = "PyCall"
4747
SciMLBasePythonCallExt = "PythonCall"
@@ -54,14 +54,19 @@ ArrayInterface = "6, 7"
5454
ChainRulesCore = "1.16"
5555
CommonSolve = "0.2.4"
5656
ConstructionBase = "1"
57+
Distributed = "1.9"
5758
DocStringExtensions = "0.8, 0.9"
5859
EnumX = "1"
5960
FillArrays = "1.6"
6061
FunctionWrappersWrappers = "0.1.3"
6162
IteratorInterfaceExtensions = "^0.1, ^1"
63+
LinearAlgebra = "1.9"
64+
Logging = "1.9"
65+
Markdown = "1.9"
6266
PartialFunctions = "1.1"
6367
PrecompileTools = "1"
6468
Preferences = "1.3"
69+
Printf = "1.9"
6570
RCall = "0.13.18"
6671
RecipesBase = "0.7.0, 0.8, 1.0"
6772
RecursiveArrayTools = "2.33"
@@ -75,11 +80,12 @@ SymbolicIndexingInterface = "0.2"
7580
Tables = "1"
7681
TruncatedStacktraces = "1"
7782
QuasiMonteCarlo = "0.3"
78-
ZygoteRules = "0.2"
79-
julia = "1.6"
83+
Zygote = "0.6"
84+
julia = "1.9"
8085

8186
[extras]
8287
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
88+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8389
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
8490
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8591
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"

src/solutions/chainrules.jl renamed to ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
module SciMLBaseChainRulesCoreExt
2+
3+
using SciMLBase
4+
import ChainRulesCore
5+
import ChainRulesCore: NoTangent, @non_differentiable
6+
17
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
28
>:ChainRulesCore.HasReverseMode,
39
},
@@ -70,3 +76,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
7076
end
7177
VA[sym], ODESolution_getindex_pullback
7278
end
79+
80+
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
81+
function ODEProblemAdjoint(ȳ)
82+
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
83+
end
84+
85+
ODEProblem(args...; kwargs...), ODEProblemAdjoint
86+
end
87+
88+
function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
89+
function SDEProblemAdjoint(ȳ)
90+
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
91+
end
92+
93+
SDEProblem(args...; kwargs...), SDEProblemAdjoint
94+
end
95+
96+
function ChainRulesCore.rrule(::Type{
97+
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
98+
T11, T12,
99+
}}, u,
100+
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
101+
T12}
102+
function ODESolutionAdjoint(ȳ)
103+
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
104+
end
105+
106+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
107+
ODESolutionAdjoint
108+
end
109+
110+
function ChainRulesCore.rrule(::Type{
111+
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
112+
ND,
113+
}}, u,
114+
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
115+
function SDESolutionAdjoint(ȳ)
116+
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
117+
end
118+
119+
SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
120+
end
121+
122+
function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged)
123+
out = EnsembleSolution(sim, time, converged)
124+
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
125+
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
126+
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
127+
(NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent())
128+
end
129+
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
130+
(NoTangent(), p̄, NoTangent(), NoTangent())
131+
end
132+
out, EnsembleSolution_adjoint
133+
end
134+
135+
end

ext/SciMLBaseZygoteExt.jl

Lines changed: 152 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module SciMLBaseZygoteExt
22

3-
using Zygote: pullback
4-
using ZygoteRules: @adjoint
5-
import ZygoteRules
6-
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved
3+
using Zygote
4+
using Zygote: @adjoint, pullback
5+
import Zygote: literal_getproperty
6+
using SciMLBase
7+
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
8+
getobserved, build_solution, EnsembleSolution,
9+
NonlinearSolution, AbstractTimeseriesSolution
710

811
# This method resolves the ambiguity with the pullback defined in
912
# RecursiveArrayToolsZygoteExt
@@ -56,25 +59,164 @@ end
5659
VA[sym, j], ODESolution_getindex_pullback
5760
end
5861

59-
ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
60-
out = EnsembleSolution(sim, time, converged)
62+
@adjoint function EnsembleSolution(sim, time, converged, stats)
63+
out = EnsembleSolution(sim, time, converged, stats)
6164
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
6265
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
6366
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
64-
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
67+
(EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing)
6568
end
6669
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
67-
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
70+
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
6871
end
6972
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
7073
(p̄, nothing, nothing, nothing)
7174
end
7275
out, EnsembleSolution_adjoint
7376
end
7477

75-
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
78+
@adjoint function getindex(VA::ODESolution, i::Int)
79+
function ODESolution_getindex_pullback(Δ)
80+
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
81+
for (x, j) in zip(VA.u, 1:length(VA))]
82+
(Δ′, nothing)
83+
end
84+
VA[i], ODESolution_getindex_pullback
85+
end
86+
87+
@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
7688
::Val{:u})
77-
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
89+
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
90+
end
91+
92+
@adjoint function getindex(VA::ODESolution, sym)
93+
function ODESolution_getindex_pullback(Δ)
94+
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
95+
if i === nothing
96+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
97+
else
98+
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
99+
for (x, j) in zip(VA.u, 1:length(VA))]
100+
(Δ′, nothing)
101+
end
102+
end
103+
VA[sym], ODESolution_getindex_pullback
104+
end
105+
106+
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
107+
}(u,
108+
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
109+
T9, T10, T11, T12}
110+
function ODESolutionAdjoint(ȳ)
111+
(ȳ, ntuple(_ -> nothing, length(args))...)
112+
end
113+
114+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
115+
ODESolutionAdjoint
116+
end
117+
118+
@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
119+
args...) where
120+
{uType, tType, isinplace, P, NP, F, G, K, ND}
121+
function SDESolutionAdjoint(ȳ)
122+
(ȳ, ntuple(_ -> nothing, length(args))...)
123+
end
124+
125+
SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
126+
end
127+
128+
@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
129+
args...) where {
130+
T,
131+
N,
132+
uType,
133+
R,
134+
P,
135+
A,
136+
O,
137+
uType2,
138+
}
139+
function NonlinearSolutionAdjoint(ȳ)
140+
(ȳ, ntuple(_ -> nothing, length(args))...)
141+
end
142+
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
143+
end
144+
145+
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
146+
::Val{:u})
147+
function solu_adjoint(Δ)
148+
zerou = zero(sol.prob.u0)
149+
= @. ifelse=== nothing, (zerou,), Δ)
150+
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
151+
end
152+
sol.u, solu_adjoint
153+
end
154+
155+
@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
156+
::Val{:u})
157+
function solu_adjoint(Δ)
158+
zerou = zero(sol.prob.u0)
159+
= @. ifelse=== nothing, zerou, Δ)
160+
(build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
161+
end
162+
sol.u, solu_adjoint
163+
end
164+
165+
@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution,
166+
::Val{:u})
167+
function solu_adjoint(Δ)
168+
zerou = zero(sol.u)
169+
= @. ifelse=== nothing, zerou, Δ)
170+
(build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
171+
end
172+
sol.u, solu_adjoint
173+
end
174+
175+
function ∇tmap(cx, f, args...)
176+
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
177+
if isempty(ys_and_backs)
178+
ys_and_backs, _ -> (NoTangent(), NoTangent())
179+
else
180+
ys, backs = Zygote.unzip(ys_and_backs)
181+
function ∇tmap_internal(Δ)
182+
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
183+
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
184+
Δf = reduce(Zygote.accum, Δf_and_args[1])
185+
(Δf, Δf_and_args[2:end]...)
186+
end
187+
ys, ∇tmap_internal
188+
end
189+
end
190+
191+
function ∇responsible_map(cx, f, args...)
192+
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
193+
args...)
194+
if isempty(ys_and_backs)
195+
ys_and_backs, _ -> (NoTangent(), NoTangent())
196+
else
197+
ys, backs = Zygote.unzip(ys_and_backs)
198+
ys,
199+
function ∇responsible_map_internal(Δ)
200+
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
201+
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
202+
Zygote._tryreverse(SciMLBase.responsible_map,
203+
backs, Δ)...)
204+
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
205+
Δf_and_args_zipped))
206+
Δf = reduce(Zygote.accum, Δf_and_args[1])
207+
(Δf, Δf_and_args[2:end]...)
208+
end
209+
end
210+
end
211+
212+
@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
213+
∇tmap(__context__, f, args...)
214+
end
215+
216+
@adjoint function SciMLBase.responsible_map(f,
217+
args::Union{AbstractArray, Tuple
218+
}...)
219+
∇responsible_map(__context__, f, args...)
78220
end
79221

80222
end

src/SciMLBase.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions
2222
import EnumX
2323
import TruncatedStacktraces
2424
import ADTypes: AbstractADType
25-
import ChainRulesCore
26-
import ZygoteRules: @adjoint
2725
import FillArrays
2826
import QuasiMonteCarlo
2927
using Reexport
@@ -716,7 +714,6 @@ include("solutions/optimization_solutions.jl")
716714
include("solutions/dae_solutions.jl")
717715
include("solutions/pde_solutions.jl")
718716
include("solutions/solution_interface.jl")
719-
include("solutions/zygote.jl")
720717

721718
include("ensemble/ensemble_solutions.jl")
722719
include("ensemble/ensemble_problems.jl")

src/ensemble/basic_ensemble_solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function batch_func(i, prob, alg; kwargs...)
105105
new_prob = prob.prob_func(_prob, i, iter)
106106
rerun = true
107107
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
108-
if !(typeof(x) <: Tuple)
108+
if !(x isa Tuple)
109109
rerun_warn()
110110
_x = (x, false)
111111
else
@@ -117,7 +117,7 @@ function batch_func(i, prob, alg; kwargs...)
117117
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
118118
new_prob = prob.prob_func(_prob, i, iter)
119119
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
120-
if !(typeof(x) <: Tuple)
120+
if !(x isa Tuple)
121121
rerun_warn()
122122
_x = (x, false)
123123
else
@@ -170,7 +170,7 @@ function solve_batch(prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_siz
170170
return solve_batch(prob, alg, EnsembleSerial(), II, pmap_batch_size; kwargs...)
171171
end
172172

173-
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
173+
if prob.prob isa AbstractJumpProblem && length(II) != 1
174174
probs = [deepcopy(prob.prob) for i in 1:nthreads]
175175
else
176176
probs = prob.prob

0 commit comments

Comments
 (0)