Skip to content

Commit 8958aea

Browse files
Merge branch 'master' into dg/cr
2 parents fd44bd4 + 9b4f6c7 commit 8958aea

12 files changed

+232
-38
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "2.35.0"
4+
version = "2.36.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -65,7 +65,7 @@ FunctionWrappersWrappers = "0.1.3"
6565
IteratorInterfaceExtensions = "^1"
6666
LinearAlgebra = "1.10"
6767
Logging = "1.10"
68-
Makie = "0.20"
68+
Makie = "0.20, 0.21"
6969
Markdown = "1.10"
7070
ModelingToolkit = "8.75, 9"
7171
PartialFunctions = "1.1"
@@ -76,15 +76,15 @@ PyCall = "1.96"
7676
PythonCall = "0.9.15"
7777
RCall = "0.14.0"
7878
RecipesBase = "1.3.4"
79-
RecursiveArrayTools = "3.8.0"
79+
RecursiveArrayTools = "3.14.0"
8080
Reexport = "1"
8181
RuntimeGeneratedFunctions = "0.5.12"
8282
SciMLOperators = "0.3.7"
8383
SciMLStructures = "1.1"
8484
StaticArrays = "1.7"
8585
StaticArraysCore = "1.4"
8686
Statistics = "1.10"
87-
SymbolicIndexingInterface = "0.3.15"
87+
SymbolicIndexingInterface = "0.3.20"
8888
Tables = "1.11"
8989
Zygote = "0.6.67"
9090
julia = "1.10"

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt
33
using SciMLBase
44
import ChainRulesCore
55
import ChainRulesCore: NoTangent, @non_differentiable
6+
using SymbolicIndexingInterface
67

78
function ChainRulesCore.rrule(
89
config::ChainRulesCore.RuleConfig{
@@ -13,7 +14,7 @@ function ChainRulesCore.rrule(
1314
sym,
1415
j::Integer)
1516
function ODESolution_getindex_pullback(Δ)
16-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
17+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
1718
if i === nothing
1819
getter = getobserved(VA)
1920
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
@@ -66,7 +67,7 @@ end
6667

6768
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
6869
function ODESolution_getindex_pullback(Δ)
69-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
70+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
7071
if i === nothing
7172
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
7273
else
@@ -109,18 +110,19 @@ function ChainRulesCore.rrule(
109110
ODESolutionAdjoint
110111
end
111112

112-
# function ChainRulesCore.rrule(
113-
# ::Type{
114-
# <:SDESolution{uType, tType, isinplace, P, NP, F, G, K,
115-
# ND
116-
# }}, u,
117-
# args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
118-
# function SDESolutionAdjoint(ȳ)
119-
# (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
120-
# end
121-
#
122-
# SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
123-
# end
113+
function ChainRulesCore.rrule(
114+
::Type{
115+
<:RODESolution{uType, tType, isinplace, P, NP, F, G, K,
116+
ND
117+
}}, u,
118+
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
119+
function RODESolutionAdjoint(ȳ)
120+
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
121+
end
122+
123+
RODESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...),
124+
RODESolutionAdjoint
125+
end
124126

125127
function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged)
126128
out = EnsembleSolution(sim, time, converged)

ext/SciMLBaseZygoteExt.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using RecursiveArrayTools
1313
# This method resolves the ambiguity with the pullback defined in
1414
# RecursiveArrayToolsZygoteExt
1515
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67
16-
@adjoint function getindex(VA::ODESolution, i::Int, j::Int)
16+
@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int)
1717
function ODESolution_getindex_pullback(Δ)
1818
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
1919
zero(VA.u[1]) for m in 1:length(VA.u)]
@@ -38,7 +38,7 @@ using RecursiveArrayTools
3838
VA[i, j], ODESolution_getindex_pullback
3939
end
4040

41-
@adjoint function getindex(VA::ODESolution, sym, j::Int)
41+
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
4242
function ODESolution_getindex_pullback(Δ)
4343
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
4444
du, dprob = if i === nothing
@@ -92,7 +92,7 @@ end
9292
out, EnsembleSolution_adjoint
9393
end
9494

95-
@adjoint function getindex(VA::ODESolution, i::Int)
95+
@adjoint function Base.getindex(VA::ODESolution, i::Int)
9696
function ODESolution_getindex_pullback(Δ)
9797
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
9898
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -106,7 +106,7 @@ end
106106
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
107107
end
108108

109-
@adjoint function getindex(VA::ODESolution, sym)
109+
@adjoint function Base.getindex(VA::ODESolution, sym)
110110
function ODESolution_getindex_pullback(Δ)
111111
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
112112
if i === nothing
@@ -120,6 +120,30 @@ end
120120
VA[sym], ODESolution_getindex_pullback
121121
end
122122

123+
@adjoint function Base.getindex(
124+
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
125+
function ODESolution_getindex_pullback(Δ)
126+
sym = sym isa Tuple ? collect(sym) : sym
127+
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
128+
if i === nothing
129+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
130+
else
131+
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
132+
map(enumerate(us)) do (u_idx, u)
133+
if u_idx in i
134+
idx = findfirst(isequal(u_idx), i)
135+
Δ[t_idx][idx]
136+
else
137+
zero(T)
138+
end
139+
end
140+
end
141+
(Δ′, nothing)
142+
end
143+
end
144+
VA[sym], ODESolution_getindex_pullback
145+
end
146+
123147
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
124148
}(u,
125149
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
@@ -135,11 +159,11 @@ end
135159
@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
136160
args...) where
137161
{uType, tType, isinplace, P, NP, F, G, K, ND}
138-
function SDESolutionAdjoint(ȳ)
162+
function SDEProblemAdjoint(ȳ)
139163
(ȳ, ntuple(_ -> nothing, length(args))...)
140164
end
141165

142-
SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
166+
SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDEProblemAdjoint
143167
end
144168

145169
@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,

src/ensemble/ensemble_solutions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ end
211211
end
212212
end
213213

214-
Base.@propagate_inbounds function RecursiveArrayTools._getindex(
215-
x::AbstractEnsembleSolution, ::Union{ScalarSymbolic, ArraySymbolic}, s, ::Colon)
214+
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
216215
return [xi[s] for xi in x.u]
217216
end
218217

src/integrator_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,8 @@ Same as `check_error` but also set solution's return code
643643
"""
644644
function check_error!(integrator::DEIntegrator)
645645
code = check_error(integrator)
646+
integrator.sol = solution_new_retcode(integrator.sol, code)
646647
if code != ReturnCode.Success
647-
integrator.sol = solution_new_retcode(integrator.sol, code)
648648
postamble!(integrator)
649649
end
650650
return code

src/solutions/ode_solutions.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,18 @@ function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense
146146
dense, tslocation, stats, alg_choice, retcode, resid, original)
147147
end
148148

149+
error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
150+
function error_if_observed_derivative(sys, idx, ::Type)
151+
if symbolic_type(idx) != NotSymbolic() && is_observed(sys, idx) ||
152+
symbolic_type(idx) == NotSymbolic() && any(x -> is_observed(sys, x), idx)
153+
error("""
154+
Cannot interpolate derivatives of observed variables. A possible solution could be
155+
interpolating the symbolic expression that evaluates to the derivative of the
156+
observed variable or using DataInterpolations.jl.
157+
""")
158+
end
159+
end
160+
149161
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
150162
continuity = :left) where {deriv}
151163
sol(t, deriv, idxs, continuity)
@@ -172,6 +184,9 @@ end
172184
function (sol::AbstractODESolution)(t::Number, ::Type{deriv},
173185
idxs::AbstractVector{<:Integer},
174186
continuity) where {deriv}
187+
if eltype(sol.u) <: Number
188+
idxs = only(idxs)
189+
end
175190
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
176191
end
177192
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
@@ -183,6 +198,9 @@ end
183198
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
184199
idxs::AbstractVector{<:Integer},
185200
continuity) where {deriv}
201+
if eltype(sol.u) <: Number
202+
idxs = only(idxs)
203+
end
186204
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
187205
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
188206
return DiffEqArray(A.u, A.t, p, sol)
@@ -191,6 +209,7 @@ end
191209
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
192210
continuity) where {deriv}
193211
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
212+
error_if_observed_derivative(sol, idxs, deriv)
194213
if is_parameter(sol, idxs)
195214
return getp(sol, idxs)(sol)
196215
else
@@ -200,15 +219,19 @@ end
200219

201220
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
202221
continuity) where {deriv}
203-
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
222+
if symbolic_type(idxs) == NotSymbolic() &&
223+
any(isequal(NotSymbolic()), symbolic_type.(idxs))
204224
error("Incorrect specification of `idxs`")
225+
end
226+
error_if_observed_derivative(sol, idxs, deriv)
205227
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
206-
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
228+
first(interp_sol[idxs])
207229
end
208230

209231
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
210232
continuity) where {deriv}
211233
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
234+
error_if_observed_derivative(sol, idxs, deriv)
212235
if is_parameter(sol, idxs)
213236
return getp(sol, idxs)(sol)
214237
else
@@ -222,10 +245,12 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
222245
idxs::AbstractVector, continuity) where {deriv}
223246
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
224247
error("Incorrect specification of `idxs`")
248+
error_if_observed_derivative(sol, idxs, deriv)
225249
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
226250
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
251+
indexed_sol = interp_sol[idxs]
227252
return DiffEqArray(
228-
[[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
253+
[indexed_sol[i] for i in 1:length(t)], t, p, sol)
229254
end
230255

231256
function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},

src/solutions/solution_interface.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,18 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
436436
plot_vecs = []
437437
labels = String[]
438438
varsyms = variable_symbols(sol)
439+
batch_symbolic_vars = []
440+
for x in vars
441+
for j in 2:length(x)
442+
if (x[j] isa Integer && x[j] == 0) || isequal(x[j], getindepsym_defaultt(sol))
443+
else
444+
push!(batch_symbolic_vars, x[j])
445+
end
446+
end
447+
end
448+
batch_symbolic_vars = identity.(batch_symbolic_vars)
449+
indexed_solution = sol(plott; idxs = batch_symbolic_vars)
450+
idxx = 0
439451
for x in vars
440452
tmp = []
441453
strs = String[]
@@ -444,7 +456,8 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
444456
push!(tmp, plott)
445457
push!(strs, "t")
446458
else
447-
push!(tmp, sol(plott; idxs = x[j]))
459+
idxx += 1
460+
push!(tmp, indexed_solution[idxx, :])
448461
if !isempty(varsyms) && x[j] isa Integer
449462
push!(strs, String(getname(varsyms[x[j]])))
450463
elseif hasname(x[j])

test/downstream/adjoints.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Zygote, Test
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
4+
@parameters σ ρ β
5+
@variables x(t) y(t) z(t)
6+
7+
eqs = [D(x) ~ σ * (y - x),
8+
D(y) ~ x *- z) - y,
9+
D(z) ~ x * y - β * z]
10+
11+
@named lorenz1 = ODESystem(eqs, t)
12+
@named lorenz2 = ODESystem(eqs, t)
13+
14+
@parameters γ
15+
@variables a(t) α(t)
16+
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
17+
α ~ 2lorenz1.x + a * γ]
18+
@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
19+
20+
u0 = [lorenz1.x => 1.0,
21+
lorenz1.y => 0.0,
22+
lorenz1.z => 0.0,
23+
lorenz2.x => 0.0,
24+
lorenz2.y => 1.0,
25+
lorenz2.z => 0.0,
26+
a => 2.0]
27+
28+
p = [lorenz1.σ => 10.0,
29+
lorenz1.ρ => 28.0,
30+
lorenz1.β => 8 / 3,
31+
lorenz2.σ => 10.0,
32+
lorenz2.ρ => 28.0,
33+
lorenz2.β => 8 / 3,
34+
γ => 2.0]
35+
36+
tspan = (0.0, 100.0)
37+
prob = ODEProblem(sys, u0, tspan, p)
38+
sol = solve(prob, Rodas4())
39+
40+
gs_sym, = Zygote.gradient(sol) do sol
41+
sum(sol[lorenz1.x])
42+
end
43+
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
44+
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
45+
true_grad_sym[idx_sym] = 1.0
46+
47+
@test all(map(x -> x == true_grad_sym, gs_sym))
48+
49+
gs_vec, = Zygote.gradient(sol) do sol
50+
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
51+
end
52+
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
53+
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
54+
true_grad_vecsym[idx_vecsym] .= 1.0
55+
56+
@test all(map(x -> x == true_grad_vecsym, gs_vec))
57+
58+
gs_tup, = Zygote.gradient(sol) do sol
59+
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
60+
end
61+
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
62+
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
63+
true_grad_tupsym[idx_tupsym] .= 1.0
64+
65+
@test all(map(x -> x == true_grad_tupsym, gs_tup))
66+
67+
gs_ts, = Zygote.gradient(sol) do sol
68+
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
69+
end
70+
71+
@test all(map(x -> x == true_grad_vecsym, gs_ts))
72+
73+
# BatchedInterface AD
74+
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
75+
@named sys1 = ODESystem([D(x) ~ x + y, D(y) ~ y * z, D(z) ~ z * t * x], t)
76+
sys1 = complete(sys1)
77+
prob1 = ODEProblem(sys1, [], (0.0, 10.0))
78+
@named sys2 = ODESystem([D(x) ~ x + w, D(y) ~ w * t, D(w) ~ x + y + w], t)
79+
sys2 = complete(sys2)
80+
prob2 = ODEProblem(sys2, [], (0.0, 10.0))
81+
82+
bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
83+
getter = getu(bi)
84+
85+
p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
86+
sum(getter(prob1, prob2))
87+
end
88+
89+
@test p1grad.u0 ones(3)
90+
testp2grad = zeros(3)
91+
testp2grad[variable_index(prob2, w)] = 1.0
92+
@test p2grad.u0 testp2grad

0 commit comments

Comments
 (0)