@@ -6,9 +6,10 @@ import Zygote: literal_getproperty
6
6
using SciMLBase
7
7
using SciMLBase: ODESolution, remake,
8
8
getobserved, build_solution, EnsembleSolution,
9
- NonlinearSolution, AbstractTimeseriesSolution
9
+ NonlinearSolution, AbstractTimeseriesSolution,
10
+ SciMLStructures
10
11
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
11
- observed, parameter_values
12
+ observed, parameter_values, state_values, current_time
12
13
using RecursiveArrayTools
13
14
14
15
# This method resolves the ambiguity with the pullback defined in
@@ -111,10 +112,13 @@ end
111
112
function ODESolution_getindex_pullback (Δ)
112
113
i = symbolic_type (sym) != NotSymbolic () ? variable_index (VA, sym) : sym
113
114
if is_observed (VA, sym)
114
- y, back = Zygote. pullback (VA) do sol
115
- f = observed (sol, sym)
116
- p = parameter_values (sol)
117
- f .(sol. u, Ref (p), sol. t)
115
+ f = observed (VA, sym)
116
+ p = parameter_values (VA)
117
+ tunables, _, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), p)
118
+ u = state_values (VA)
119
+ t = current_time (VA)
120
+ y, back = Zygote. pullback (u, tunables) do u, tunables
121
+ f .(u, Ref (tunables), t)
118
122
end
119
123
gs = back (Δ)
120
124
(gs[1 ], nothing )
@@ -154,8 +158,7 @@ function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
154
158
end
155
159
end
156
160
157
- nt = Zygote. nt_nothing (VA)
158
- Zygote. accum (nt, (u = Δ′,))
161
+ Δ′
159
162
end
160
163
161
164
@adjoint function Base. getindex (
171
174
gs_not_obs = not_obs_grads (VA, sym, not_obs_idx, i, Δ)
172
175
173
176
a = Zygote. accum (gs_obs[1 ], gs_not_obs)
177
+
174
178
(a, nothing )
175
179
end
176
180
VA[sym], ODESolution_getindex_pullback
0 commit comments