Skip to content

Commit b4487e9

Browse files
Improve float conversions in PI controllers
This is fairly hard to test, but it basically stems from `float(1//50)::Float64`, which means that this was always doing some float64 stuff, even when it should've been doing fastpow stuff. The test case just really required a Float32: ```julia using DiffEqCallbacks, OrdinaryDiffEq, Tracker Base.prevfloat(r::Tracker.TrackedReal) = Tracker.track(prevfloat, r) Tracker.@Grad function prevfloat(r::Real) prevfloat(Tracker.data(r)), Δ -> (Δ,) end Base.nextfloat(r::Tracker.TrackedReal) = Tracker.track(nextfloat, r) Tracker.@Grad function nextfloat(r::Real) nextfloat(Tracker.data(r)), Δ -> (Δ,) end function rober(u, p::TrackedArray, t) y₁, y₂, y₃ = u k₁, k₂, k₃ = p return Tracker.collect([-k₁ * y₁ + k₃ * y₂ * y₃, k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃, k₂ * y₂^2]) end p = TrackedArray([1.9f0, 1.0f0, 3.0f0]) u0 = TrackedArray([1.0f0, 0.0f0, 0.0f0]) tspan = TrackedArray([0.0f0, 1.0f0]) prob = ODEProblem{false}(rober, u0, tspan, p) p = TrackedArray([1.9f0, 1.0f0, 3.0f0]) u0 = TrackedArray([1.0f0, 0.0f0, 0.0f0]) tspan = TrackedArray([0.0f0, 1.0f0]) prob = ODEProblem{false}(rober, u0, tspan, p) saved_values = SavedValues(eltype(tspan), eltype(p)) cb = SavingCallback((u, t, integrator) -> integrator.EEst * integrator.dt, saved_values) solve(remake(prob, u0 = u0, p = p, tspan = tspan), Tsit5(), sensealg = SensitivityADPassThrough(), callback = cb) @test !all(iszero.(Tracker.gradient( p -> begin solve(remake(prob, u0 = u0, p = p, tspan = tspan), Tsit5(), sensealg = SensitivityADPassThrough(), callback = cb) return sum(saved_values.saveval) end, p)[1])) ``` Thus downstream tests with FastPower.jl used catches this, and it's somewhat hard to construct a case that's this sensitive to the type.
1 parent f82f6d8 commit b4487e9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lib/OrdinaryDiffEqCore/src/integrators/controllers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ end
141141
if iszero(EEst)
142142
q = inv(qmax)
143143
else
144-
q11 = FastPower.fastpower(EEst, float(beta1))
145-
q = q11 / FastPower.fastpower(qold, float(beta2))
144+
q11 = FastPower.fastpower(EEst, convert(typeof(EEst),beta1))
145+
q = q11 / FastPower.fastpower(qold, convert(typeof(EEst),beta2))
146146
integrator.q11 = q11
147147
@fastmath q = max(inv(qmax), min(inv(qmin), q / gamma))
148148
end

0 commit comments

Comments
 (0)