|
1 | 1 | module SciMLBaseZygoteExt
|
2 | 2 |
|
3 |
| -using Zygote: pullback |
4 |
| -using ZygoteRules: @adjoint |
5 |
| -import ZygoteRules |
6 |
| -using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved |
| 3 | +using Zygote |
| 4 | +using Zygote: @adjoint, pullback |
| 5 | +import Zygote: literal_getproperty |
| 6 | +using SciMLBase |
| 7 | +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, |
| 8 | + getobserved, build_solution, EnsembleSolution, |
| 9 | + NonlinearSolution, AbstractTimeseriesSolution |
7 | 10 |
|
8 | 11 | # This method resolves the ambiguity with the pullback defined in
|
9 | 12 | # RecursiveArrayToolsZygoteExt
|
|
56 | 59 | VA[sym, j], ODESolution_getindex_pullback
|
57 | 60 | end
|
58 | 61 |
|
59 |
| -ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) |
60 |
| - out = EnsembleSolution(sim, time, converged) |
| 62 | +@adjoint function EnsembleSolution(sim, time, converged, stats) |
| 63 | + out = EnsembleSolution(sim, time, converged, stats) |
61 | 64 | function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
|
62 | 65 | arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
|
63 | 66 | for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
|
64 |
| - (EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing) |
| 67 | + (EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing) |
65 | 68 | end
|
66 | 69 | function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
|
67 |
| - (EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing) |
| 70 | + (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) |
68 | 71 | end
|
69 | 72 | function EnsembleSolution_adjoint(p̄::EnsembleSolution)
|
70 | 73 | (p̄, nothing, nothing, nothing)
|
71 | 74 | end
|
72 | 75 | out, EnsembleSolution_adjoint
|
73 | 76 | end
|
74 | 77 |
|
75 |
| -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, |
| 78 | +@adjoint function getindex(VA::ODESolution, i::Int) |
| 79 | + function ODESolution_getindex_pullback(Δ) |
| 80 | + Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x))) |
| 81 | + for (x, j) in zip(VA.u, 1:length(VA))] |
| 82 | + (Δ′, nothing) |
| 83 | + end |
| 84 | + VA[i], ODESolution_getindex_pullback |
| 85 | +end |
| 86 | + |
| 87 | +@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution, |
76 | 88 | ::Val{:u})
|
77 |
| - sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),) |
| 89 | + sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) |
| 90 | +end |
| 91 | + |
| 92 | +@adjoint function getindex(VA::ODESolution, sym) |
| 93 | + function ODESolution_getindex_pullback(Δ) |
| 94 | + i = issymbollike(sym) ? sym_to_index(sym, VA) : sym |
| 95 | + if i === nothing |
| 96 | + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) |
| 97 | + else |
| 98 | + Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] |
| 99 | + for (x, j) in zip(VA.u, 1:length(VA))] |
| 100 | + (Δ′, nothing) |
| 101 | + end |
| 102 | + end |
| 103 | + VA[sym], ODESolution_getindex_pullback |
| 104 | +end |
| 105 | + |
| 106 | +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 |
| 107 | + }(u, |
| 108 | + args...) where {T1, T2, T3, T4, T5, T6, T7, T8, |
| 109 | + T9, T10, T11, T12} |
| 110 | + function ODESolutionAdjoint(ȳ) |
| 111 | + (ȳ, ntuple(_ -> nothing, length(args))...) |
| 112 | + end |
| 113 | + |
| 114 | + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), |
| 115 | + ODESolutionAdjoint |
| 116 | +end |
| 117 | + |
| 118 | +@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, |
| 119 | + args...) where |
| 120 | + {uType, tType, isinplace, P, NP, F, G, K, ND} |
| 121 | + function SDESolutionAdjoint(ȳ) |
| 122 | + (ȳ, ntuple(_ -> nothing, length(args))...) |
| 123 | + end |
| 124 | + |
| 125 | + SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint |
| 126 | +end |
| 127 | + |
| 128 | +@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, |
| 129 | + args...) where { |
| 130 | + T, |
| 131 | + N, |
| 132 | + uType, |
| 133 | + R, |
| 134 | + P, |
| 135 | + A, |
| 136 | + O, |
| 137 | + uType2, |
| 138 | +} |
| 139 | + function NonlinearSolutionAdjoint(ȳ) |
| 140 | + (ȳ, ntuple(_ -> nothing, length(args))...) |
| 141 | + end |
| 142 | + NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint |
| 143 | +end |
| 144 | + |
| 145 | +@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, |
| 146 | + ::Val{:u}) |
| 147 | + function solu_adjoint(Δ) |
| 148 | + zerou = zero(sol.prob.u0) |
| 149 | + _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) |
| 150 | + (build_solution(sol.prob, sol.alg, sol.t, _Δ),) |
| 151 | + end |
| 152 | + sol.u, solu_adjoint |
| 153 | +end |
| 154 | + |
| 155 | +@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, |
| 156 | + ::Val{:u}) |
| 157 | + function solu_adjoint(Δ) |
| 158 | + zerou = zero(sol.prob.u0) |
| 159 | + _Δ = @. ifelse(Δ === nothing, zerou, Δ) |
| 160 | + (build_solution(sol.prob, sol.alg, _Δ, sol.resid),) |
| 161 | + end |
| 162 | + sol.u, solu_adjoint |
| 163 | +end |
| 164 | + |
| 165 | +@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution, |
| 166 | + ::Val{:u}) |
| 167 | + function solu_adjoint(Δ) |
| 168 | + zerou = zero(sol.u) |
| 169 | + _Δ = @. ifelse(Δ === nothing, zerou, Δ) |
| 170 | + (build_solution(sol.cache, sol.alg, _Δ, sol.objective),) |
| 171 | + end |
| 172 | + sol.u, solu_adjoint |
| 173 | +end |
| 174 | + |
| 175 | +function ∇tmap(cx, f, args...) |
| 176 | + ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) |
| 177 | + if isempty(ys_and_backs) |
| 178 | + ys_and_backs, _ -> (NoTangent(), NoTangent()) |
| 179 | + else |
| 180 | + ys, backs = Zygote.unzip(ys_and_backs) |
| 181 | + function ∇tmap_internal(Δ) |
| 182 | + Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) |
| 183 | + Δf_and_args = Zygote.unzip(Δf_and_args_zipped) |
| 184 | + Δf = reduce(Zygote.accum, Δf_and_args[1]) |
| 185 | + (Δf, Δf_and_args[2:end]...) |
| 186 | + end |
| 187 | + ys, ∇tmap_internal |
| 188 | + end |
| 189 | +end |
| 190 | + |
| 191 | +function ∇responsible_map(cx, f, args...) |
| 192 | + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), |
| 193 | + args...) |
| 194 | + if isempty(ys_and_backs) |
| 195 | + ys_and_backs, _ -> (NoTangent(), NoTangent()) |
| 196 | + else |
| 197 | + ys, backs = Zygote.unzip(ys_and_backs) |
| 198 | + ys, |
| 199 | + function ∇responsible_map_internal(Δ) |
| 200 | + # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. |
| 201 | + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), |
| 202 | + Zygote._tryreverse(SciMLBase.responsible_map, |
| 203 | + backs, Δ)...) |
| 204 | + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, |
| 205 | + Δf_and_args_zipped)) |
| 206 | + Δf = reduce(Zygote.accum, Δf_and_args[1]) |
| 207 | + (Δf, Δf_and_args[2:end]...) |
| 208 | + end |
| 209 | + end |
| 210 | +end |
| 211 | + |
| 212 | +@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) |
| 213 | + ∇tmap(__context__, f, args...) |
| 214 | +end |
| 215 | + |
| 216 | +@adjoint function SciMLBase.responsible_map(f, |
| 217 | + args::Union{AbstractArray, Tuple |
| 218 | + }...) |
| 219 | + ∇responsible_map(__context__, f, args...) |
78 | 220 | end
|
79 | 221 |
|
80 | 222 | end
|
0 commit comments