@@ -129,13 +129,13 @@ end
129
129
VA[sym], ODESolution_getindex_pullback
130
130
end
131
131
132
- function obs_grads (VA, sym, obss_idx , Δ)
132
+ function obs_grads (VA, sym, obs_idx , Δ)
133
133
y, back = Zygote. pullback (VA) do sol
134
- getindex .(Ref (sol), sym[obss_idx ])
134
+ getindex .(Ref (sol), sym[obs_idx ])
135
135
end
136
- Dprime = reduce (hcat, Δ)
137
- Dobss = eachrow (Dprime[obss_idx , :])
138
- back (Dobss )
136
+ Δreduced = reduce (hcat, Δ)
137
+ Δobs = eachrow (Δreduced[obs_idx , :])
138
+ back (Δobs )
139
139
end
140
140
141
141
function obs_grads (VA, sym, :: Nothing , Δ)
@@ -164,11 +164,11 @@ end
164
164
sym = sym isa Tuple ? collect (sym) : sym
165
165
i = map (x -> symbolic_type (x) != NotSymbolic () ? variable_index (VA, x) : x, sym)
166
166
167
- obss_idx = findall (s -> is_observed (VA, s), sym)
168
- not_obss_idx = setdiff (1 : length (sym), obss_idx )
167
+ obs_idx = findall (s -> is_observed (VA, s), sym)
168
+ not_obs_idx = setdiff (1 : length (sym), obs_idx )
169
169
170
- gs_obs = obs_grads (VA, sym, isempty (obss_idx ) ? nothing : obss_idx , Δ)
171
- gs_not_obs = not_obs_grads (VA, sym, not_obss_idx , i, Δ)
170
+ gs_obs = obs_grads (VA, sym, isempty (obs_idx ) ? nothing : obs_idx , Δ)
171
+ gs_not_obs = not_obs_grads (VA, sym, not_obs_idx , i, Δ)
172
172
173
173
a = Zygote. accum (gs_obs[1 ], gs_not_obs)
174
174
(a, nothing )
220
220
function solu_adjoint (Δ)
221
221
zerou = zero (sol. prob. u0)
222
222
_Δ = @. ifelse (Δ === nothing , (zerou,), Δ)
223
- (build_solution (sol. prob, sol. alg, sol. t, _Δ),)
223
+ nt = Zygote. nt_nothing (sol)
224
+ gs = Zygote. accum (nt, (u = _Δ,))
225
+ (gs,)
224
226
end
225
227
sol. u, solu_adjoint
226
228
end
0 commit comments