|
1 | 1 | module SciMLBaseZygoteExt
|
2 | 2 |
|
3 |
| -using Zygote: pullback |
| 3 | +using Zygote |
| 4 | +using Zygote: pullback, ZygoteRules |
4 | 5 | using ZygoteRules: @adjoint
|
5 |
| -import ZygoteRules |
6 |
| -using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved |
| 6 | +using SciMLBase |
| 7 | +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved |
7 | 8 |
|
8 | 9 | # This method resolves the ambiguity with the pullback defined in
|
9 | 10 | # RecursiveArrayToolsZygoteExt
|
|
56 | 57 | VA[sym, j], ODESolution_getindex_pullback
|
57 | 58 | end
|
58 | 59 |
|
59 |
# Adjoint for the four-argument `EnsembleSolution` constructor.
#
# The pullback has one method per supported cotangent form:
#   * an `N`-dimensional array, which is split into a vector of vectors by
#     indexing its last two dimensions,
#   * a vector of arrays, which is wrapped directly,
#   * an `EnsembleSolution`, which is passed through unchanged.
# In every case only the first constructor argument receives a cotangent; the
# remaining three get `nothing`.
@adjoint function EnsembleSolution(sim, time, converged, stats)
    # Wrap a cotangent container in an `EnsembleSolution` that reuses `stats`.
    rewrap(ū) = (EnsembleSolution(ū, 0.0, true, stats), nothing, nothing, nothing)

    function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
        dims = size(p̄)
        nested = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
                   for j in 1:dims[end - 1]] for i in 1:dims[end]]
        rewrap(nested)
    end
    EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) = rewrap(p̄)
    EnsembleSolution_adjoint(p̄::EnsembleSolution) = (p̄, nothing, nothing, nothing)

    EnsembleSolution(sim, time, converged, stats), EnsembleSolution_adjoint
end
|
74 | 75 |
|
75 |
| -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, |
# Adjoint for integer indexing `VA[i]` on an `ODESolution`.
#
# The pullback builds one cotangent entry per element of `VA.u`: the incoming
# `Δ` at position `i`, and a lazy zero-filled array of matching size elsewhere.
# `Fill` is reached through `Zygote.FillArrays` because this module does not
# import FillArrays itself.
@adjoint function getindex(VA::ODESolution, i::Int)
    function ODESolution_getindex_pullback(Δ)
        Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
              for (x, j) in zip(VA.u, 1:length(VA))]
        (Δ′, nothing)
    end
    VA[i], ODESolution_getindex_pullback
end
| 84 | + |
# Adjoint for the property access `sim.u` on an `EnsembleSolution`.
#
# The cotangent of `u` is wrapped in a new `EnsembleSolution` that reuses
# `sim.stats`.
@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u})
    function EnsembleSolution_u_pullback(p̄)
        (EnsembleSolution(p̄, 0.0, true, sim.stats),)
    end
    sim.u, EnsembleSolution_u_pullback
end
| 89 | + |
# Adjoint for `VA[sym]` on an `ODESolution`, where `sym` is either symbol-like
# or already an index.
#
# The pullback resolves `sym` to an index `i` (via `sym_to_index` when `sym` is
# symbol-like) and raises an error if no index is found. Otherwise it builds,
# for each element `x` of `VA.u`, a vector that holds `Δ[j]` at position `i`
# and `zero(x[1])` everywhere else.
@adjoint function getindex(VA::ODESolution, sym)
    function ODESolution_getindex_pullback(Δ)
        i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
        if i === nothing
            # `error` already throws, so no surrounding `throw` is needed.
            error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
        end
        Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
              for (x, j) in zip(VA.u, 1:length(VA))]
        (Δ′, nothing)
    end
    VA[sym], ODESolution_getindex_pullback
end
| 103 | + |
# Adjoint for the fully parameterised `ODESolution` constructor.
#
# The incoming cotangent is forwarded unchanged to the first argument `u`;
# every remaining constructor argument receives `nothing`.
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
        args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
        T9, T10, T11, T12}
    # One `nothing` per trailing argument (mapping over a tuple yields a tuple).
    no_grads = map(_ -> nothing, args)
    ODESolutionAdjoint(ȳ) = (ȳ, no_grads...)

    sol = ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...)
    sol, ODESolutionAdjoint
end
| 115 | + |
# Adjoint for the fully parameterised `SDEProblem` constructor.
#
# The forward value must be the object produced by the constructor being
# differentiated, so it is built with `SDEProblem{...}` using the same type
# parameters as the method signature. The pullback forwards the incoming
# cotangent to the first argument and returns `nothing` for every other
# argument.
@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
        args...) where
        {uType, tType, isinplace, P, NP, F, G, K, ND}
    function SDEProblemAdjoint(ȳ)
        (ȳ, ntuple(_ -> nothing, length(args))...)
    end

    SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDEProblemAdjoint
end
| 125 | + |
# Adjoint for the fully parameterised `NonlinearSolution` constructor.
#
# The incoming cotangent is forwarded unchanged to the first argument `u`;
# every remaining constructor argument receives `nothing`.
@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
        args...) where {T, N, uType, R, P, A, O, uType2}
    # One `nothing` per trailing argument (mapping over a tuple yields a tuple).
    no_grads = map(_ -> nothing, args)
    NonlinearSolutionAdjoint(ȳ) = (ȳ, no_grads...)

    sol = NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...)
    sol, NonlinearSolutionAdjoint
end
| 142 | + |
# Adjoint for the property access `sol.u` on a timeseries solution.
#
# The pullback replaces every `nothing` entry of the cotangent `Δ` with a zero
# of the same shape as `sol.prob.u0`, then wraps the result in a solution
# object built from `sol.prob`, `sol.alg` and `sol.t`.
# `build_solution` and the abstract solution type are referenced through
# `SciMLBase`, which this module imports; `DiffEqBase` is not in scope here.
@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.AbstractTimeseriesSolution,
        ::Val{:u})
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        # `(zerou,)` is a 1-tuple so that broadcasting treats it as a scalar.
        _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
        (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
    end
    sol.u, solu_adjoint
end
| 152 | + |
# Adjoint for the property access `sol.u` on a solution without a time axis.
#
# The pullback substitutes zeros (taken from `zero(sol.prob.u0)`) wherever the
# cotangent `Δ` is `nothing`, then wraps the result in a solution object built
# from `sol.prob`, `sol.alg` and `sol.resid`.
# `build_solution` and the abstract solution type are referenced through
# `SciMLBase`, which this module imports; `DiffEqBase` is not in scope here.
@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
        ::Val{:u})
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        _Δ = @. ifelse(Δ === nothing, zerou, Δ)
        (SciMLBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
    end
    sol.u, solu_adjoint
end
| 162 | + |
# Adjoint for the property access `sol.u` on an `OptimizationSolution`.
#
# The pullback substitutes zeros (taken from `zero(sol.u)`) wherever the
# cotangent `Δ` is `nothing`, then wraps the result in a solution object built
# from `sol.cache`, `sol.alg` and `sol.objective`.
# `build_solution` is referenced through `SciMLBase`, which this module
# imports; `DiffEqBase` is not in scope here.
@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution,
        ::Val{:u})
    function solu_adjoint(Δ)
        zerou = zero(sol.u)
        _Δ = @. ifelse(Δ === nothing, zerou, Δ)
        (SciMLBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
    end
    sol.u, solu_adjoint
end
| 172 | + |
# Reverse-mode rule body for `SciMLBase.tmap(f, args...)`.
#
# Runs `Zygote._pullback(cx, f, ...)` elementwise through `SciMLBase.tmap` and
# returns the primal results together with a pullback. The pullback yields one
# entry for `f` followed by one entry per element of `args`.
function ∇tmap(cx, f, args...)
    ys_and_backs = SciMLBase.tmap((xs...) -> Zygote._pullback(cx, f, xs...), args...)
    if isempty(ys_and_backs)
        # Nothing was mapped: there is no gradient for `f` or for any argument.
        # `nothing` is the no-gradient marker used by the other `@adjoint`
        # rules in this file; `NoTangent` is not in scope in this module.
        no_grads = ntuple(_ -> nothing, length(args) + 1)
        return ys_and_backs, _ -> no_grads
    end
    ys, backs = Zygote.unzip(ys_and_backs)
    function ∇tmap_internal(Δ)
        Δf_and_args_zipped = SciMLBase.tmap((back, δ) -> back(δ), backs, Δ)
        Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
        # Every call contributes a gradient for `f`; accumulate them into one.
        Δf = reduce(Zygote.accum, Δf_and_args[1])
        return (Δf, Δf_and_args[2:end]...)
    end
    return ys, ∇tmap_internal
end
| 188 | + |
# Reverse-mode rule body for `SciMLBase.responsible_map(f, args...)`.
#
# Runs `Zygote._pullback(cx, f, ...)` elementwise through
# `SciMLBase.responsible_map` and returns the primal results together with a
# pullback. The pullback yields one entry for `f` followed by one entry per
# element of `args`.
function ∇responsible_map(cx, f, args...)
    ys_and_backs = SciMLBase.responsible_map((xs...) -> Zygote._pullback(cx, f, xs...),
        args...)
    if isempty(ys_and_backs)
        # Nothing was mapped: there is no gradient for `f` or for any argument.
        # `nothing` is the no-gradient marker used by the other `@adjoint`
        # rules in this file; `NoTangent` is not in scope in this module.
        no_grads = ntuple(_ -> nothing, length(args) + 1)
        return ys_and_backs, _ -> no_grads
    end
    ys, backs = Zygote.unzip(ys_and_backs)
    function ∇responsible_map_internal(Δ)
        # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
        Δf_and_args_zipped = SciMLBase.responsible_map((back, δ) -> back(δ),
            Zygote._tryreverse(SciMLBase.responsible_map,
                backs, Δ)...)
        Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
            Δf_and_args_zipped))
        # Every call contributes a gradient for `f`; accumulate them into one.
        Δf = reduce(Zygote.accum, Δf_and_args[1])
        return (Δf, Δf_and_args[2:end]...)
    end
    return ys, ∇responsible_map_internal
end
| 209 | + |
# Zygote rule for `SciMLBase.tmap` over arrays and tuples; the work is
# delegated to `∇tmap` together with the current AD context.
@adjoint SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) = ∇tmap(__context__, f,
    args...)
| 213 | + |
# Zygote rule for `SciMLBase.responsible_map` over arrays and tuples; the work
# is delegated to `∇responsible_map` together with the current AD context.
@adjoint SciMLBase.responsible_map(f, args::Union{AbstractArray, Tuple}...) = ∇responsible_map(
    __context__, f, args...)
|
79 | 219 |
|
80 | 220 | end
|
0 commit comments