You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
y, back = Zygote.pullback(u, tunables) do u, tunables
121
+
f.(u, Ref(tunables), t)
122
+
end
123
+
gs =back(Δ)
124
+
(gs[1], nothing)
125
+
elseif i ===nothing
113
126
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
114
127
else
115
128
Δ′ = [[i == k ? Δ[j] :zero(x[1]) for k in1:length(x)]
@@ -120,26 +133,49 @@ end
120
133
VA[sym], ODESolution_getindex_pullback
121
134
end
122
135
136
+
functionobs_grads(VA, sym, obs_idx, Δ)
137
+
y, back = Zygote.pullback(VA) do sol
138
+
getindex.(Ref(sol), sym[obs_idx])
139
+
end
140
+
Δreduced =reduce(hcat, Δ)
141
+
Δobs =eachrow(Δreduced[obs_idx, :])
142
+
back(Δobs)
143
+
end
144
+
145
+
functionobs_grads(VA, sym, ::Nothing, Δ)
146
+
Zygote.nt_nothing(VA)
147
+
end
148
+
149
+
functionnot_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
150
+
Δ′ =map(enumerate(VA.u)) do (t_idx, us)
151
+
map(enumerate(us)) do (u_idx, u)
152
+
if u_idx in i
153
+
idx =findfirst(isequal(u_idx), i)
154
+
Δ[t_idx][idx]
155
+
else
156
+
zero(T)
157
+
end
158
+
end
159
+
end
160
+
161
+
Δ′
162
+
end
163
+
123
164
@adjointfunction Base.getindex(
124
165
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
125
166
functionODESolution_getindex_pullback(Δ)
126
167
sym = sym isa Tuple ?collect(sym) : sym
127
168
i =map(x ->symbolic_type(x) !=NotSymbolic() ?variable_index(VA, x) : x, sym)
128
-
if i ===nothing
129
-
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
0 commit comments