Skip to content

Commit 2368679

Browse files
committed
make oop mutation free
1 parent e05ccf8 commit 2368679

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -697,15 +697,14 @@ end
697697

698698
### Rodas4 methods
699699

700-
struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD, rateType} <: RosenbrockConstantCache
700+
struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
701701
tf::TF
702702
uf::UF
703703
tab::Tab
704704
J::JType
705705
W::WType
706706
linsolve::F
707707
autodiff::AD
708-
ks::Vector{rateType}
709708
end
710709

711710
tabtype(::Rodas4) = Rodas4Tableau
@@ -722,11 +721,10 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, u, rate_proto
722721
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
723722
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
724723
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
725-
ks = Vector{typeof(rate_prototype)}(undef, 6)
726724
Rodas4ConstantCache(tf, uf,
727725
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
728726
constvalue(tTypeNoUnits)), J, W, linsolve,
729-
alg_autodiff(alg), ks)
727+
alg_autodiff(alg))
730728
end
731729

732730
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, u, rate_prototype, ::Type{uEltypeNoUnits},

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,7 @@ end
12081208

12091209
@muladd function perform_step!(integrator, cache::Rodas4ConstantCache, repeat_step = false)
12101210
(;t, dt, uprev, u, f, p) = integrator
1211-
(;tf, uf, ks) = cache
1211+
(;tf, uf) = cache
12121212
(;A, C, gamma, c, d, H) = cache.tab
12131213

12141214
# Precalculations
@@ -1228,11 +1228,15 @@ end
12281228
return nothing
12291229
end
12301230

1231-
# Initialize k arrays
1231+
# Initialize ks
12321232
num_stages = size(A,1)
1233-
1233+
du = f(u, p, t)
1234+
linsolve_tmp = @.. du + dtd[1] * dT
1235+
k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
1236+
# constant number for type stability make sure this is greater than num_stages
1237+
ks = ntuple(Returns(k1), 10)
12341238
# Loop for stages
1235-
for stage in 1:num_stages
1239+
for stage in 2:num_stages
12361240
u = uprev
12371241
for i in 1:stage-1
12381242
u = @.. u + A[stage, i] * ks[i]
@@ -1255,7 +1259,7 @@ end
12551259
end
12561260
linsolve_tmp = @.. du + dtd[stage] * dT + linsolve_tmp1
12571261

1258-
ks[stage] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
1262+
ks = Base.setindex(ks, _reshape(W \ -_vec(linsolve_tmp), axes(uprev)), stage)
12591263
integrator.stats.nsolve += 1
12601264
end
12611265
#@show ks
@@ -1271,7 +1275,7 @@ end
12711275
for j in eachindex(integrator.k)
12721276
integrator.k[j] = zero(integrator.k[1])
12731277
end
1274-
for i in eachindex(ks)
1278+
for i in 1:num_stages
12751279
for j in eachindex(integrator.k)
12761280
integrator.k[j] = @.. integrator.k[j] + H[j, i] * ks[i]
12771281
end

lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache,
291291
always_calc_begin = false, allow_calc_end = true,
292292
force_calc_end = false)
293293
if length(k) < 2 || always_calc_begin
294-
(;tf, uf, ks) = cache
294+
(;tf, uf) = cache
295295
(;A, C, gamma, c, d, H) = cache.tab
296296

297297
# Precalculations
@@ -317,10 +317,16 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache,
317317
J = ForwardDiff.derivative(uf, uprev)
318318
W = 1 / dtgamma - J
319319
end
320-
321-
num_stages = size(A, 1)
320+
321+
322+
num_stages = size(A,1)
323+
du = f(u, p, t)
324+
linsolve_tmp = @.. du + dtd[1] * dT
325+
k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
326+
# constant number for type stability make sure this is greater than num_stages
327+
ks = ntuple(Returns(k1), 10)
322328
# Last stage doesn't affect ks
323-
for stage in 1:num_stages-1
329+
for stage in 2:num_stages-1
324330
u = uprev
325331
for i in 1:stage-1
326332
u = @.. u + A[stage, i] * ks[i]
@@ -341,14 +347,14 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache,
341347
end
342348
end
343349

344-
ks[stage] = _reshape(W \ _vec(linsolve_tmp), axes(uprev))
350+
ks = Base.setindex(ks, _reshape(W \ _vec(linsolve_tmp), axes(uprev)), stage)
345351
end
346352

347353
k1 = zero(ks[1])
348354
k2 = zero(ks[1])
349355
H = cache.tab.H
350356
# Last stage doesn't affect ks
351-
for i in 1:length(ks)-1
357+
for i in 1:num_stages-1
352358
k1 = @.. k1 + H[1, i] * ks[i]
353359
k2 = @.. k2 + H[2, i] * ks[i]
354360
end

0 commit comments

Comments
 (0)