|
1 | 1 | module SciMLBaseZygoteExt
|
2 | 2 |
|
3 | 3 | using Zygote
|
4 |
| -using Zygote: pullback, ZygoteRules |
5 |
| -using ZygoteRules: @adjoint |
| 4 | +using Zygote: @adjoint, pullback |
| 5 | +import Zygote: literal_getproperty |
6 | 6 | using SciMLBase
|
7 |
| -using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved |
| 7 | +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, |
| 8 | + getobserved, build_solution, EnsembleSolution, |
| 9 | + NonlinearSolution, AbstractTimeseriesSolution |
8 | 10 |
|
9 | 11 | # This method resolves the ambiguity with the pullback defined in
|
10 | 12 | # RecursiveArrayToolsZygoteExt
|
|
82 | 84 | VA[i], ODESolution_getindex_pullback
|
83 | 85 | end
|
84 | 86 |
|
85 |
| -@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, |
| 87 | +@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution, |
86 | 88 | ::Val{:u})
|
87 | 89 | sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
|
88 | 90 | end
|
@@ -140,32 +142,32 @@ end
|
140 | 142 | NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
|
141 | 143 | end
|
142 | 144 |
|
143 |
| -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, |
| 145 | +@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, |
144 | 146 | ::Val{:u})
|
145 | 147 | function solu_adjoint(Δ)
|
146 | 148 | zerou = zero(sol.prob.u0)
|
147 | 149 | _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
|
148 |
| - (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) |
| 150 | + (build_solution(sol.prob, sol.alg, sol.t, _Δ),) |
149 | 151 | end
|
150 | 152 | sol.u, solu_adjoint
|
151 | 153 | end
|
152 | 154 |
|
153 |
| -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, |
| 155 | +@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, |
154 | 156 | ::Val{:u})
|
155 | 157 | function solu_adjoint(Δ)
|
156 | 158 | zerou = zero(sol.prob.u0)
|
157 | 159 | _Δ = @. ifelse(Δ === nothing, zerou, Δ)
|
158 |
| - (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) |
| 160 | + (build_solution(sol.prob, sol.alg, _Δ, sol.resid),) |
159 | 161 | end
|
160 | 162 | sol.u, solu_adjoint
|
161 | 163 | end
|
162 | 164 |
|
163 |
| -@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, |
| 165 | +@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution, |
164 | 166 | ::Val{:u})
|
165 | 167 | function solu_adjoint(Δ)
|
166 | 168 | zerou = zero(sol.u)
|
167 | 169 | _Δ = @. ifelse(Δ === nothing, zerou, Δ)
|
168 |
| - (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) |
| 170 | + (build_solution(sol.cache, sol.alg, _Δ, sol.objective),) |
169 | 171 | end
|
170 | 172 | sol.u, solu_adjoint
|
171 | 173 | end
|
|
0 commit comments