Skip to content

Commit 17b664c

Browse files
hotfix missing imports
1 parent 8b5e58d commit 17b664c

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "2.7.1"
4+
version = "2.7.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/SciMLBaseZygoteExt.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module SciMLBaseZygoteExt
22

33
using Zygote
4-
using Zygote: pullback, ZygoteRules
5-
using ZygoteRules: @adjoint
4+
using Zygote: @adjoint, pullback
5+
import Zygote: literal_getproperty
66
using SciMLBase
7-
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
7+
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
8+
getobserved, build_solution, EnsembleSolution,
9+
NonlinearSolution, AbstractTimeseriesSolution
810

911
# This method resolves the ambiguity with the pullback defined in
1012
# RecursiveArrayToolsZygoteExt
@@ -82,7 +84,7 @@ end
8284
VA[i], ODESolution_getindex_pullback
8385
end
8486

85-
@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
87+
@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
8688
::Val{:u})
8789
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
8890
end
@@ -140,32 +142,32 @@ end
140142
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
141143
end
142144

143-
@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution,
145+
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
144146
::Val{:u})
145147
function solu_adjoint(Δ)
146148
zerou = zero(sol.prob.u0)
147149
= @. ifelse=== nothing, (zerou,), Δ)
148-
(DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
150+
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
149151
end
150152
sol.u, solu_adjoint
151153
end
152154

153-
@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution,
155+
@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
154156
::Val{:u})
155157
function solu_adjoint(Δ)
156158
zerou = zero(sol.prob.u0)
157159
= @. ifelse=== nothing, zerou, Δ)
158-
(DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
160+
(build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
159161
end
160162
sol.u, solu_adjoint
161163
end
162164

163-
@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution,
165+
@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution,
164166
::Val{:u})
165167
function solu_adjoint(Δ)
166168
zerou = zero(sol.u)
167169
= @. ifelse=== nothing, zerou, Δ)
168-
(DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
170+
(build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
169171
end
170172
sol.u, solu_adjoint
171173
end

0 commit comments

Comments
 (0)