Skip to content

Commit de2d6cd

Browse files
chore: don't return structural tangent
1 parent 44bfc91 commit de2d6cd

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import Zygote: literal_getproperty
66
using SciMLBase
77
using SciMLBase: ODESolution, remake,
88
getobserved, build_solution, EnsembleSolution,
9-
NonlinearSolution, AbstractTimeseriesSolution
9+
NonlinearSolution, AbstractTimeseriesSolution,
10+
SciMLStructures
1011
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
11-
observed, parameter_values
12+
observed, parameter_values, state_values, current_time
1213
using RecursiveArrayTools
1314

1415
# This method resolves the ambiguity with the pullback defined in
@@ -111,10 +112,13 @@ end
111112
function ODESolution_getindex_pullback(Δ)
112113
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
113114
if is_observed(VA, sym)
114-
y, back = Zygote.pullback(VA) do sol
115-
f = observed(sol, sym)
116-
p = parameter_values(sol)
117-
f.(sol.u, Ref(p), sol.t)
115+
f = observed(VA, sym)
116+
p = parameter_values(VA)
117+
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
118+
u = state_values(VA)
119+
t = current_time(VA)
120+
y, back = Zygote.pullback(u, tunables) do u, tunables
121+
f.(u, Ref(tunables), t)
118122
end
119123
gs = back(Δ)
120124
(gs[1], nothing)
@@ -154,8 +158,7 @@ function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
154158
end
155159
end
156160

157-
nt = Zygote.nt_nothing(VA)
158-
Zygote.accum(nt, (u = Δ′,))
161+
Δ′
159162
end
160163

161164
@adjoint function Base.getindex(
@@ -171,6 +174,7 @@ end
171174
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
172175

173176
a = Zygote.accum(gs_obs[1], gs_not_obs)
177+
174178
(a, nothing)
175179
end
176180
VA[sym], ODESolution_getindex_pullback

test/downstream/observables_autodiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ sol = solve(prob, Tsit5())
3535
end
3636
du_ = [0.0, 1.0, 1.0, 1.0]
3737
du = [du_ for _ in sol.u]
38-
@test du == gs.u
38+
@test du == gs
3939

4040
# Observable in a vector
4141
gs, = gradient(sol) do sol
4242
sum(sum.(sol[[sys.w, sys.x]]))
4343
end
4444
du_ = [0.0, 1.0, 1.0, 2.0]
4545
du = [du_ for _ in sol.u]
46-
@test du == gs.u
46+
@test du == gs
4747
end
4848

4949
# DAE
@@ -84,7 +84,7 @@ end
8484
end
8585
du_ = [0.2, 1.0]
8686
du = [du_ for _ in sol.u]
87-
@test gs.u == du
87+
@test gs == du
8888
end
8989

9090
# @testset "Adjoints with DAE" begin

0 commit comments

Comments
 (0)