Skip to content

Commit 3811745

Browse files
Merge pull request #689 from DhairyaLGandhi/dg/obsfn
Feat: adjoints through observable functions
2 parents 5b172dc + f817b52 commit 3811745

File tree

5 files changed

+162
-18
lines changed

5 files changed

+162
-18
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ jobs:
1414
test:
1515
runs-on: ubuntu-latest
1616
strategy:
17+
fail-fast: false
1718
matrix:
1819
group:
1920
- Core
@@ -47,4 +48,4 @@ jobs:
4748
with:
4849
file: lcov.info
4950
token: ${{ secrets.CODECOV_TOKEN }}
50-
fail_ci_if_error: true
51+
fail_ci_if_error: false

ext/SciMLBaseZygoteExt.jl

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ using SciMLBase
77
using SciMLBase: ODESolution, remake,
88
getobserved, build_solution, EnsembleSolution,
99
NonlinearSolution, AbstractTimeseriesSolution
10-
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
10+
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
11+
observed, parameter_values, state_values, current_time
1112
using RecursiveArrayTools
13+
import SciMLStructures
1214

1315
# This method resolves the ambiguity with the pullback defined in
1416
# RecursiveArrayToolsZygoteExt
@@ -109,7 +111,18 @@ end
109111
@adjoint function Base.getindex(VA::ODESolution, sym)
110112
function ODESolution_getindex_pullback(Δ)
111113
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
112-
if i === nothing
114+
if is_observed(VA, sym)
115+
f = observed(VA, sym)
116+
p = parameter_values(VA)
117+
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
118+
u = state_values(VA)
119+
t = current_time(VA)
120+
y, back = Zygote.pullback(u, tunables) do u, tunables
121+
f.(u, Ref(tunables), t)
122+
end
123+
gs = back(Δ)
124+
(gs[1], nothing)
125+
elseif i === nothing
113126
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
114127
else
115128
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
@@ -120,26 +133,49 @@ end
120133
VA[sym], ODESolution_getindex_pullback
121134
end
122135

136+
function obs_grads(VA, sym, obs_idx, Δ)
137+
y, back = Zygote.pullback(VA) do sol
138+
getindex.(Ref(sol), sym[obs_idx])
139+
end
140+
Δreduced = reduce(hcat, Δ)
141+
Δobs = eachrow(Δreduced[obs_idx, :])
142+
back(Δobs)
143+
end
144+
145+
function obs_grads(VA, sym, ::Nothing, Δ)
146+
Zygote.nt_nothing(VA)
147+
end
148+
149+
function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
150+
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
151+
map(enumerate(us)) do (u_idx, u)
152+
if u_idx in i
153+
idx = findfirst(isequal(u_idx), i)
154+
Δ[t_idx][idx]
155+
else
156+
zero(T)
157+
end
158+
end
159+
end
160+
161+
Δ′
162+
end
163+
123164
@adjoint function Base.getindex(
124165
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
125166
function ODESolution_getindex_pullback(Δ)
126167
sym = sym isa Tuple ? collect(sym) : sym
127168
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
128-
if i === nothing
129-
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
130-
else
131-
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
132-
map(enumerate(us)) do (u_idx, u)
133-
if u_idx in i
134-
idx = findfirst(isequal(u_idx), i)
135-
Δ[t_idx][idx]
136-
else
137-
zero(T)
138-
end
139-
end
140-
end
141-
(Δ′, nothing)
142-
end
169+
170+
obs_idx = findall(s -> is_observed(VA, s), sym)
171+
not_obs_idx = setdiff(1:length(sym), obs_idx)
172+
173+
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
174+
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
175+
176+
a = Zygote.accum(gs_obs[1], gs_not_obs)
177+
178+
(a, nothing)
143179
end
144180
VA[sym], ODESolution_getindex_pullback
145181
end

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
33
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
44
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
55
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
6+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
67
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
78
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
89
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
@@ -22,6 +23,7 @@ BoundaryValueDiffEq = "5"
2223
ForwardDiff = "0.10"
2324
JumpProcesses = "9.10"
2425
ModelingToolkit = "8.37, 9"
26+
ModelingToolkitStandardLibrary = "2.7"
2527
NonlinearSolve = "2, 3"
2628
Optimization = "3"
2729
OptimizationMOI = "0.4"
test/downstream/observables_autodiff.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
using Zygote
3+
using ModelingToolkit: t_nounits as t, D_nounits as D
4+
import SymbolicIndexingInterface as SII
5+
import SciMLStructures as SS
6+
using ModelingToolkitStandardLibrary
7+
import ModelingToolkitStandardLibrary as MSL
8+
9+
@parameters σ ρ β
10+
@variables x(t) y(t) z(t) w(t)
11+
12+
eqs = [D(D(x)) ~ σ * (y - x),
13+
D(y) ~ x * (ρ - z) - y,
14+
D(z) ~ x * y - β * z,
15+
w ~ x + y + z + 2 * β]
16+
17+
@mtkbuild sys = ODESystem(eqs, t)
18+
19+
u0 = [D(x) => 2.0,
20+
x => 1.0,
21+
y => 0.0,
22+
z => 0.0]
23+
24+
p = [σ => 28.0,
25+
ρ => 10.0,
26+
β => 8 / 3]
27+
28+
tspan = (0.0, 100.0)
29+
prob = ODEProblem(sys, u0, tspan, p, jac = true)
30+
sol = solve(prob, Tsit5())
31+
32+
@testset "AutoDiff Observable Functions" begin
33+
gs, = gradient(sol) do sol
34+
sum(sol[sys.w])
35+
end
36+
du_ = [0.0, 1.0, 1.0, 1.0]
37+
du = [du_ for _ in sol.u]
38+
@test du == gs
39+
40+
# Observable in a vector
41+
gs, = gradient(sol) do sol
42+
sum(sum.(sol[[sys.w, sys.x]]))
43+
end
44+
du_ = [0.0, 1.0, 1.0, 2.0]
45+
du = [du_ for _ in sol.u]
46+
@test du == gs
47+
end
48+
49+
# DAE
50+
51+
function create_model(; C₁ = 3e-5, C₂ = 1e-6)
52+
@variables t
53+
@named resistor1 = MSL.Electrical.Resistor(R = 5.0)
54+
@named resistor2 = MSL.Electrical.Resistor(R = 2.0)
55+
@named capacitor1 = MSL.Electrical.Capacitor(C = C₁)
56+
@named capacitor2 = MSL.Electrical.Capacitor(C = C₂)
57+
@named source = MSL.Electrical.Voltage()
58+
@named input_signal = MSL.Blocks.Sine(frequency = 100.0)
59+
@named ground = MSL.Electrical.Ground()
60+
@named ampermeter = MSL.Electrical.CurrentSensor()
61+
62+
eqs = [connect(input_signal.output, source.V)
63+
connect(source.p, capacitor1.n, capacitor2.n)
64+
connect(source.n, resistor1.p, resistor2.p, ground.g)
65+
connect(resistor1.n, capacitor1.p, ampermeter.n)
66+
connect(resistor2.n, capacitor2.p, ampermeter.p)]
67+
68+
@named circuit_model = ODESystem(eqs, t,
69+
systems = [
70+
resistor1, resistor2, capacitor1, capacitor2,
71+
source, input_signal, ground, ampermeter
72+
])
73+
end
74+
75+
@testset "DAE Observable function AD" begin
76+
model = create_model()
77+
sys = structural_simplify(model)
78+
79+
prob = ODEProblem(sys, [], (0.0, 1.0))
80+
sol = solve(prob, Rodas4())
81+
82+
gs, = gradient(sol) do sol
83+
sum(sol[sys.ampermeter.i])
84+
end
85+
du_ = [0.2, 1.0]
86+
du = [du_ for _ in sol.u]
87+
@test gs == du
88+
end
89+
90+
# @testset "Adjoints with DAE" begin
91+
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
92+
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
93+
# new_prob = remake(prob, p = new_p)
94+
# sol = solve(new_prob, Rodas4())
95+
# @show size(sol)
96+
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
97+
# sum(sol[sys.ampermeter.i])
98+
# end
99+
#
100+
# @test isnothing(gs_mtkp)
101+
# @test length(gs_p_new) == length(p_new)
102+
# end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ end
110110
@time @safetestset "Partial Functions" begin
111111
include("downstream/partial_functions.jl")
112112
end
113+
@time @safetestset "Autodiff Observable Functions" begin
114+
include("downstream/observables_autodiff.jl")
115+
end
113116
end
114117

115118
if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")

0 commit comments

Comments
 (0)