Skip to content

Commit ecdcc33

Browse files
Merge pull request #537 from SciML/rules
Move over the rest of pirating rules
2 parents 5a7771d + b960eae commit ecdcc33

File tree

5 files changed

+215
-37
lines changed

5 files changed

+215
-37
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "2.6.0"
66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
109
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1110
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1211
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -31,16 +30,17 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3130
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3231
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3332
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
34-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3533

3634
[weakdeps]
35+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3736
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
3837
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
3938
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
4039
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
4140
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4241

4342
[extensions]
43+
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
4444
SciMLBasePartialFunctionsExt = "PartialFunctions"
4545
SciMLBasePyCallExt = "PyCall"
4646
SciMLBasePythonCallExt = "PythonCall"
@@ -78,11 +78,12 @@ Statistics = "1"
7878
SymbolicIndexingInterface = "0.2"
7979
Tables = "1"
8080
TruncatedStacktraces = "1"
81-
ZygoteRules = "0.2"
81+
Zygote = "0.6"
8282
julia = "1.6"
8383

8484
[extras]
8585
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
86+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8687
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
8788
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8889
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"

src/solutions/chainrules.jl renamed to ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
module SciMLBaseChainRulesCoreExt
2+
3+
import ChainRulesCore
4+
import ChainRulesCore: NoTangent, @non_differentiable
5+
16
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
27
>:ChainRulesCore.HasReverseMode,
38
},
@@ -70,3 +75,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
7075
end
7176
VA[sym], ODESolution_getindex_pullback
7277
end
78+
79+
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
80+
function ODEProblemAdjoint(ȳ)
81+
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
82+
end
83+
84+
ODEProblem(args...; kwargs...), ODEProblemAdjoint
85+
end
86+
87+
function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
88+
function SDEProblemAdjoint(ȳ)
89+
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
90+
end
91+
92+
SDEProblem(args...; kwargs...), SDEProblemAdjoint
93+
end
94+
95+
function ChainRulesCore.rrule(::Type{
96+
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
97+
T11, T12,
98+
}}, u,
99+
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
100+
T12}
101+
function ODESolutionAdjoint(ȳ)
102+
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
103+
end
104+
105+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
106+
ODESolutionAdjoint
107+
end
108+
109+
function ChainRulesCore.rrule(::Type{
110+
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
111+
ND,
112+
}}, u,
113+
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
114+
function SDESolutionAdjoint(ȳ)
115+
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
116+
end
117+
118+
SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
119+
end
120+
121+
function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged)
122+
out = EnsembleSolution(sim, time, converged)
123+
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
124+
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
125+
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
126+
(NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent())
127+
end
128+
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
129+
(NoTangent(), p̄, NoTangent(), NoTangent())
130+
end
131+
out, EnsembleSolution_adjoint
132+
end
133+
134+
end

ext/SciMLBaseZygoteExt.jl

Lines changed: 149 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module SciMLBaseZygoteExt
22

3-
using Zygote: pullback
3+
using Zygote
4+
using Zygote: pullback, ZygoteRules
45
using ZygoteRules: @adjoint
5-
import ZygoteRules
6-
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved
6+
using SciMLBase
7+
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
78

89
# This method resolves the ambiguity with the pullback defined in
910
# RecursiveArrayToolsZygoteExt
@@ -56,25 +57,164 @@ end
5657
VA[sym, j], ODESolution_getindex_pullback
5758
end
5859

59-
ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
60-
out = EnsembleSolution(sim, time, converged)
60+
@adjoint function EnsembleSolution(sim, time, converged, stats)
61+
out = EnsembleSolution(sim, time, converged, stats)
6162
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
6263
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
6364
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
64-
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
65+
(EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing)
6566
end
6667
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
67-
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
68+
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
6869
end
6970
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
7071
(p̄, nothing, nothing, nothing)
7172
end
7273
out, EnsembleSolution_adjoint
7374
end
7475

75-
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
76+
@adjoint function getindex(VA::ODESolution, i::Int)
77+
function ODESolution_getindex_pullback(Δ)
78+
Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x)))
79+
for (x, j) in zip(VA.u, 1:length(VA))]
80+
(Δ′, nothing)
81+
end
82+
VA[i], ODESolution_getindex_pullback
83+
end
84+
85+
@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
7686
::Val{:u})
77-
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
87+
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
88+
end
89+
90+
@adjoint function getindex(VA::ODESolution, sym)
91+
function ODESolution_getindex_pullback(Δ)
92+
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
93+
if i === nothing
94+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
95+
else
96+
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
97+
for (x, j) in zip(VA.u, 1:length(VA))]
98+
(Δ′, nothing)
99+
end
100+
end
101+
VA[sym], ODESolution_getindex_pullback
102+
end
103+
104+
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
105+
}(u,
106+
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
107+
T9, T10, T11, T12}
108+
function ODESolutionAdjoint(ȳ)
109+
(ȳ, ntuple(_ -> nothing, length(args))...)
110+
end
111+
112+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
113+
ODESolutionAdjoint
114+
end
115+
116+
@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
117+
args...) where
118+
{uType, tType, isinplace, P, NP, F, G, K, ND}
119+
function SDESolutionAdjoint(ȳ)
120+
(ȳ, ntuple(_ -> nothing, length(args))...)
121+
end
122+
123+
SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
124+
end
125+
126+
@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
127+
args...) where {
128+
T,
129+
N,
130+
uType,
131+
R,
132+
P,
133+
A,
134+
O,
135+
uType2,
136+
}
137+
function NonlinearSolutionAdjoint(ȳ)
138+
(ȳ, ntuple(_ -> nothing, length(args))...)
139+
end
140+
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
141+
end
142+
143+
@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution,
144+
::Val{:u})
145+
function solu_adjoint(Δ)
146+
zerou = zero(sol.prob.u0)
147+
= @. ifelse=== nothing, (zerou,), Δ)
148+
(DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
149+
end
150+
sol.u, solu_adjoint
151+
end
152+
153+
@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution,
154+
::Val{:u})
155+
function solu_adjoint(Δ)
156+
zerou = zero(sol.prob.u0)
157+
= @. ifelse=== nothing, zerou, Δ)
158+
(DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
159+
end
160+
sol.u, solu_adjoint
161+
end
162+
163+
@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution,
164+
::Val{:u})
165+
function solu_adjoint(Δ)
166+
zerou = zero(sol.u)
167+
= @. ifelse=== nothing, zerou, Δ)
168+
(DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
169+
end
170+
sol.u, solu_adjoint
171+
end
172+
173+
function ∇tmap(cx, f, args...)
174+
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
175+
if isempty(ys_and_backs)
176+
ys_and_backs, _ -> (NoTangent(), NoTangent())
177+
else
178+
ys, backs = Zygote.unzip(ys_and_backs)
179+
function ∇tmap_internal(Δ)
180+
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
181+
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
182+
Δf = reduce(Zygote.accum, Δf_and_args[1])
183+
(Δf, Δf_and_args[2:end]...)
184+
end
185+
ys, ∇tmap_internal
186+
end
187+
end
188+
189+
function ∇responsible_map(cx, f, args...)
190+
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
191+
args...)
192+
if isempty(ys_and_backs)
193+
ys_and_backs, _ -> (NoTangent(), NoTangent())
194+
else
195+
ys, backs = Zygote.unzip(ys_and_backs)
196+
ys,
197+
function ∇responsible_map_internal(Δ)
198+
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
199+
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
200+
Zygote._tryreverse(SciMLBase.responsible_map,
201+
backs, Δ)...)
202+
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
203+
Δf_and_args_zipped))
204+
Δf = reduce(Zygote.accum, Δf_and_args[1])
205+
(Δf, Δf_and_args[2:end]...)
206+
end
207+
end
208+
end
209+
210+
@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
211+
∇tmap(__context__, f, args...)
212+
end
213+
214+
@adjoint function SciMLBase.responsible_map(f,
215+
args::Union{AbstractArray, Tuple
216+
}...)
217+
∇responsible_map(__context__, f, args...)
78218
end
79219

80220
end

src/SciMLBase.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions
2222
import EnumX
2323
import TruncatedStacktraces
2424
import ADTypes: AbstractADType
25-
import ChainRulesCore
26-
import ZygoteRules: @adjoint
2725
import FillArrays
2826

2927
using Reexport
@@ -716,7 +714,6 @@ include("solutions/optimization_solutions.jl")
716714
include("solutions/dae_solutions.jl")
717715
include("solutions/pde_solutions.jl")
718716
include("solutions/solution_interface.jl")
719-
include("solutions/zygote.jl")
720717

721718
include("ensemble/ensemble_solutions.jl")
722719
include("ensemble/ensemble_problems.jl")

src/solutions/zygote.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)