Skip to content

Commit d86a2b8

Browse files
committed
use arrays rather than dictionaries when threading
1 parent 9d7db3d commit d86a2b8

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/ThreadedSensitivities.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function threadedsensitivities(react; odesolver=nothing, senssolver=nothing,
3737

3838

3939
# Parallelize the SA calculations
40-
solutiondictionary = Dict()
40+
solutions = Array{ODESolution}(undef,length(react.p))
4141

4242
nthreads = Threads.nthreads()
4343
if nthreads > 1 #each thread needs its own Reactor
@@ -71,10 +71,10 @@ function threadedsensitivities(react; odesolver=nothing, senssolver=nothing,
7171
odefcn = ODEFunction(dsdt!)
7272
prob = ODEProblem(odefcn, zeros(length(r.y0)),r.tspan,0)
7373
s = solve(prob, senssolver; senskwargs...)
74-
solutiondictionary[i] = s
74+
solutions[i] = s
7575
end
7676

77-
bigsol = generatesenssolution(sol,solutiondictionary,reactsens.ode)
77+
bigsol = generatesenssolution(sol,solutions,reactsens.ode)
7878
return bigsol
7979
end
8080

@@ -115,14 +115,15 @@ function threadedsensitivities(react, paramindices; odesolver=nothing, senssolve
115115

116116

117117
# Parallelize the SA calculations
118-
solutiondictionary = Dict()
118+
solutions = Array{ODESolution}(undef,length(paramindices))
119119
nthreads = Threads.nthreads()
120120
if nthreads > 1 #each thread needs its own Reactor
121121
reacts = [deepcopy(react) for i in 1:nthreads]
122122
else
123123
reacts = [react]
124124
end
125-
@threads for i in paramindices
125+
@threads for n in 1:length(paramindices)
126+
i = paramindices[n]
126127
if nthreads > 1
127128
id = Threads.threadid()
128129
r = reacts[id]
@@ -147,9 +148,9 @@ function threadedsensitivities(react, paramindices; odesolver=nothing, senssolve
147148
odefcn = ODEFunction(dsdt!)
148149
prob = ODEProblem(odefcn, zeros(length(r.y0)),r.tspan,0)
149150
s = solve(prob, senssolver; senskwargs...)
150-
solutiondictionary[i] = s
151+
solutions[n] = s
151152
end
152-
153+
solutiondictionary = [i=>solutions[n] for (n,i) in enumerate(paramindices)]
153154
return solutiondictionary
154155
end
155156

@@ -158,11 +159,10 @@ export threadedsensitivities
158159
"""
159160
Combine ODE solutions into a sensitivity solution
160161
"""
161-
function generatesenssolution(sol, sensdict, sensprob)
162+
function generatesenssolution(sol, senssolns, sensprob)
162163
ts = sol.t
163-
ordkeys = sort([x for x in keys(sensdict)])
164+
u = [vcat(sol.u[i],(senssolns[k](ts[i]) for k in 1:length(senssolns))...) for i in 1:length(sol.u)]
164165
bigsol = build_solution(sensprob, sol.alg, ts, u;
165166
interp=LinearInterpolation(ts, u), retcode=sol.retcode)
166-
u = [vcat(sol.u[i],(sensdict[k](ts[i]) for k in ordkeys)...) for i in 1:length(sol.u)]
167167
return bigsol
168168
end

0 commit comments

Comments
 (0)