Skip to content

Commit 75168c8

Browse files
small edits
1 parent 9c15269 commit 75168c8

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,14 +520,14 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
520520
Convergence, J)
521521
end
522522

523-
mutable struct AdaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType, W1Type, W2Type,
523+
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
524524
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
525525
FIRKMutableCache
526526
u::uType
527527
uprev::uType
528528
z::Vector{uType}
529529
w::Vector{uType}
530-
c_prime::Vector{BigFloat}
530+
c_prime::Vector{tType}
531531
dw1::uType
532532
ubuff::uType
533533
dw2::Vector{cuType}
@@ -589,7 +589,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
589589
z[i] = w[i] = zero(u)
590590
end
591591

592-
c_prime = Vector{BigFloat}(undef, num_stages) #time stepping
592+
c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping
593593

594594
dw1 = zero(u)
595595
ubuff = zero(u)

lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,12 +1498,20 @@ end
14981498

14991499
# transform `w` to `z`
15001500
#z = T * w
1501-
for i in 1:num_stages
1501+
for i in 1:num_stages - 1
15021502
z[i] = zero(u)
15031503
for j in 1:num_stages
15041504
z[i] += T[i,j] * w[j]
15051505
end
15061506
end
1507+
z[num_stages] = T[num_stages, 1] * w[1]
1508+
i = 2
1509+
while i < num_stages
1510+
z[num_stages] += w[i]
1511+
i += 2
1512+
end
1513+
1514+
15071515
# check stopping criterion
15081516
iter > 1 &&= θ / (1 - θ))
15091517
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
@@ -1524,13 +1532,16 @@ end
15241532
cache.iter = iter
15251533

15261534
u = @.. uprev + z[num_stages]
1527-
#=
1535+
15281536
if adaptive
15291537
edt = e ./ dt
1530-
tmp = @.. dot(edt, z)
1538+
tmp = dot(edt, z)
15311539
mass_matrix != I && (tmp = mass_matrix * tmp)
15321540
utilde = @.. broadcast=false integrator.fsalfirst+tmp
1533-
alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
1541+
if alg.smooth_est
1542+
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
1543+
integrator.stats.nsolve += 1
1544+
end
15341545
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
15351546
integrator.EEst = internalnorm(atmp, t)
15361547

@@ -1539,12 +1550,15 @@ end
15391550
f0 = f(uprev .+ utilde, p, t)
15401551
integrator.stats.nf += 1
15411552
utilde = @.. broadcast=false f0+tmp
1542-
alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
1553+
if alg.smooth_est
1554+
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
1555+
integrator.stats.nsolve += 1
1556+
end
15431557
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
15441558
integrator.EEst = internalnorm(atmp, t)
15451559
end
15461560
end
1547-
=#
1561+
15481562
if integrator.EEst <= oneunit(integrator.EEst)
15491563
cache.dtprev = dt
15501564
if alg.extrapolant != :constant
@@ -1729,12 +1743,19 @@ end
17291743

17301744
# transform `w` to `z`
17311745
#mul!(z, T, w)
1732-
for i in 1:num_stages
1746+
for i in 1:num_stages - 1
17331747
z[i] = zero(u)
17341748
for j in 1:num_stages
17351749
z[i] += T[i,j] * w[j]
17361750
end
17371751
end
1752+
z[num_stages] = T[num_stages, 1] * w[1]
1753+
i = 2
1754+
while i < num_stages
1755+
z[num_stages] += w[i]
1756+
i += 2
1757+
end
1758+
17381759
# check stopping criterion
17391760
iter > 1 &&= θ / (1 - θ))
17401761
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))

0 commit comments

Comments
 (0)