Skip to content

Commit 032b927

Browse files
chore: try to avoid returning object
1 parent ff9bb2c commit 032b927

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,13 @@ end
129129
VA[sym], ODESolution_getindex_pullback
130130
end
131131

132-
function obs_grads(VA, sym, obss_idx, Δ)
132+
function obs_grads(VA, sym, obs_idx, Δ)
133133
y, back = Zygote.pullback(VA) do sol
134-
getindex.(Ref(sol), sym[obss_idx])
134+
getindex.(Ref(sol), sym[obs_idx])
135135
end
136-
Dprime = reduce(hcat, Δ)
137-
Dobss = eachrow(Dprime[obss_idx, :])
138-
back(Dobss)
136+
Δreduced = reduce(hcat, Δ)
137+
Δobs = eachrow(Δreduced[obs_idx, :])
138+
back(Δobs)
139139
end
140140

141141
function obs_grads(VA, sym, ::Nothing, Δ)
@@ -164,11 +164,11 @@ end
164164
sym = sym isa Tuple ? collect(sym) : sym
165165
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
166166

167-
obss_idx = findall(s -> is_observed(VA, s), sym)
168-
not_obss_idx = setdiff(1:length(sym), obss_idx)
167+
obs_idx = findall(s -> is_observed(VA, s), sym)
168+
not_obs_idx = setdiff(1:length(sym), obs_idx)
169169

170-
gs_obs = obs_grads(VA, sym, isempty(obss_idx) ? nothing : obss_idx, Δ)
171-
gs_not_obs = not_obs_grads(VA, sym, not_obss_idx, i, Δ)
170+
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
171+
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
172172

173173
a = Zygote.accum(gs_obs[1], gs_not_obs)
174174
(a, nothing)
@@ -220,7 +220,9 @@ end
220220
function solu_adjoint(Δ)
221221
zerou = zero(sol.prob.u0)
222222
= @. ifelse=== nothing, (zerou,), Δ)
223-
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
223+
nt = Zygote.nt_nothing(sol)
224+
gs = Zygote.accum(nt, (u = _Δ,))
225+
(gs,)
224226
end
225227
sol.u, solu_adjoint
226228
end

0 commit comments

Comments
 (0)