Skip to content

Commit 3456aec

Browse files
committed
fix oop BDF gamma type and terk_tmp type
1 parent cb87c1c commit 3456aec

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

lib/OrdinaryDiffEqBDF/src/algorithms.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ end
150150
function QNDF1(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
151151
concrete_jac = nothing, diff_type = Val{:forward},
152152
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
153-
extrapolant = :linear, kappa = -0.1850,
153+
extrapolant = :linear, kappa = -37//200,
154154
controller = :Standard, step_limiter! = trivial_limiter!)
155155
QNDF1{
156156
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
@@ -233,7 +233,7 @@ function QNDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
233233
diff_type = Val{:forward},
234234
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
235235
tol = nothing,
236-
extrapolant = :linear, kappa = promote(-0.1850, -1 // 9, -0.0823, -0.0415, 0),
236+
extrapolant = :linear, kappa = (-37//200, -1//9, -823//10000, -83//2000, 0//1),
237237
controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
238238
QNDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
239239
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),

lib/OrdinaryDiffEqBDF/src/bdf_caches.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
358358
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
359359
} where {MO}
360360
max_order = MO
361-
γ, c = one(eltype(alg.kappa)), 1
361+
γ, c = one(uEltypeNoUnits), 1
362362
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
363363
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
364364
dtprev = one(dt)
@@ -539,7 +539,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
539539
dt, reltol, p, calck,
540540
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
541541
} where {MO}
542-
γ, c = 1.0, 1.0
542+
γ, c = one(uEltypeNoUnits), 1
543543
max_order = MO
544544
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
545545
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
@@ -614,7 +614,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
614614
dt, reltol, p, calck,
615615
::Val{true}) where {MO, uEltypeNoUnits, uBottomEltypeNoUnits,
616616
tTypeNoUnits}
617-
γ, c = 1.0, 1.0
617+
γ, c = one(uEltypeNoUnits), 1
618618
fsalfirst = zero(rate_prototype)
619619
max_order = MO
620620
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,

lib/OrdinaryDiffEqBDF/src/controllers.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function choose_order!(alg::FBDF, integrator,
172172
terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u
173173
vc = _vec(terk_tmp)
174174
for i in 2:(k - 2)
175-
@.. broadcast=false @views vc += fd_weights[i, k - 2] * u_history[:, i - 1]
175+
@.. @views vc += fd_weights[i, k - 2] * u_history[:, i - 1]
176176
end
177177
@.. broadcast=false terk_tmp*=abs(dt^(k - 2))
178178
calculate_residuals!(atmp, _vec(terk_tmp), _vec(uprev), _vec(u),
@@ -204,22 +204,24 @@ function choose_order!(alg::FBDF, integrator,
204204
terkm1 = terkm2
205205
fd_weights = calc_finite_difference_weights(ts_tmp, t + dt, k - 2,
206206
Val(max_order))
207-
terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u
207+
local terk_tmp
208208
if u isa Number
209+
terk_tmp = fd_weights[k - 2, 1]*u
209210
for i in 2:(k - 2)
210211
terk_tmp += fd_weights[i, k - 2] * u_history[i - 1]
211212
end
212213
terk_tmp *= abs(dt^(k - 2))
213214
else
214-
vc = _vec(terk_tmp)
215+
# we need terk_tmp to be mutable.
216+
# so it can be updated
217+
terk_tmp = similar(u)
218+
@.. terk_tmp = fd_weights[k - 2, 1]*_vec(u)
215219
for i in 2:(k - 2)
216-
@.. broadcast=false @views vc += fd_weights[i, k - 2] *
217-
u_history[:, i - 1]
220+
@.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1]
218221
end
219-
terk_tmp = reshape(vc, size(terk_tmp))
220-
terk_tmp *= @.. broadcast=false abs(dt^(k - 2))
222+
@.. terk_tmp *= abs(dt^(k - 2))
221223
end
222-
atmp = calculate_residuals(_vec(terk_tmp), _vec(uprev), _vec(u),
224+
atmp = calculate_residuals(terk_tmp, _vec(uprev), _vec(u),
223225
integrator.opts.abstol, integrator.opts.reltol,
224226
integrator.opts.internalnorm, t)
225227
terkm2 = integrator.opts.internalnorm(atmp, t)

test/interface/linear_solver_test.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,14 @@ end
161161
using OrdinaryDiffEq, StaticArrays, LinearSolve, ParameterizedFunctions
162162

163163
hires = @ode_def Hires begin
164-
dy1 = -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007
165-
dy2 = 1.71 * y1 - 8.75 * y2
166-
dy3 = -10.03 * y3 + 0.43 * y4 + 0.035 * y5
167-
dy4 = 8.32 * y2 + 1.71 * y3 - 1.12 * y4
168-
dy5 = -1.745 * y5 + 0.43 * y6 + 0.43 * y7
169-
dy6 = -280.0 * y6 * y8 + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7
170-
dy7 = 280.0 * y6 * y8 - 1.81 * y7
171-
dy8 = -280.0 * y6 * y8 + 1.81 * y7
164+
dy1 = -1.71f0 * y1 + 0.43f0 * y2 + 8.32f0 * y3 + 0.0007f0 + 1f-18*t
165+
dy2 = 1.71f0 * y1 - 8.75f0 * y2
166+
dy3 = -10.03f0 * y3 + 0.43f0 * y4 + 0.035f0 * y5
167+
dy4 = 8.32f0 * y2 + 1.71f0 * y3 - 1.12f0 * y4
168+
dy5 = -1.745f0 * y5 + 0.43f0 * y6 + 0.43f0 * y7
169+
dy6 = -280.0f0 * y6 * y8 + 0.69f0 * y4 + 1.71f0 * y5 - 0.43f0 * y6 + 0.69f0 * y7
170+
dy7 = 280.0f0 * y6 * y8 - 1.81f0 * y7
171+
dy8 = -280.0f0 * y6 * y8 + 1.81f0 * y7
172172
end
173173

174174
u0 = zeros(8)
@@ -178,7 +178,11 @@ u0[8] = 0.0057
178178
probiip = ODEProblem{true}(hires, u0, (0.0, 10.0))
179179
proboop = ODEProblem{false}(hires, u0, (0.0, 10.0))
180180
probstatic = ODEProblem{false}(hires, SVector{8}(u0), (0.0, 10.0))
181+
probiipf32 = ODEProblem{true}(hires, Float32.(u0), (0f0, 10f0))
182+
proboopf32 = ODEProblem{false}(hires, Float32.(u0), (0f0, 10f0))
183+
probstaticf32 = ODEProblem{false}(hires, SVector{8}(Float32.(u0)), (0f0, 10f0))
181184
probs = (; probiip, proboop, probstatic)
185+
probsf32 = (;probiipf32, proboopf32, probstaticf32)
182186
qndf = QNDF()
183187
krylov_qndf = QNDF(linsolve = KrylovJL_GMRES())
184188
fbdf = FBDF()
@@ -197,3 +201,13 @@ refsol = solve(probiip, FBDF(), abstol = 1e-12, reltol = 1e-12)
197201
end
198202
end
199203
end
204+
205+
@testset "Hires Float32 calc_W tests" begin
206+
@testset "$probname" for (probname, prob) in pairs(probsf32)
207+
@testset "$solname" for (solname, solver) in pairs(solvers)
208+
sol = solve(prob, solver, maxiters = 2e4)
209+
@test sol.retcode == ReturnCode.Success
210+
@test isapprox(sol.u[end], refsol.u[end], rtol = 1e-2, atol = 1e-5)
211+
end
212+
end
213+
end

0 commit comments

Comments
 (0)