Skip to content

Commit c48737e

Browse files
authored
Merge pull request #251 from ReactionMechanismGenerator/optimize_threaded_sensitivities
Optimize Threaded Sensitivities
2 parents a4ee101 + d86a2b8 commit c48737e

File tree

1 file changed

+65
-45
lines changed

1 file changed

+65
-45
lines changed

src/ThreadedSensitivities.jl

Lines changed: 65 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,46 @@ function threadedsensitivities(react; odesolver=nothing, senssolver=nothing,
3535
forwardsensitivities=true, forwarddiff=react.forwarddiff, modelingtoolkit=react.modelingtoolkit,
3636
tau=react.tau)
3737

38-
salist = generatesensitivityodes(react, sol)
3938

4039
# Parallelize the SA calculations
41-
solutiondictionary = Dict()
40+
solutions = Array{ODESolution}(undef,length(react.p))
41+
42+
nthreads = Threads.nthreads()
43+
if nthreads > 1 #each thread needs its own Reactor
44+
reacts = [deepcopy(react) for i in 1:nthreads]
45+
else
46+
reacts = [react]
47+
end
4248

4349
@threads for i in 1:length(react.p)
44-
s = solve(salist[i], senssolver; senskwargs...)
45-
solutiondictionary[i] = s
50+
if nthreads > 1
51+
id = Threads.threadid()
52+
r = reacts[id]
53+
else
54+
r = react
55+
end
56+
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q<:Real,V} = jacobiany!(J,y,p,t,r.domain,r.interfaces,nothing)
57+
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q<:Real,V} = jacobianp!(J,y,p,t,r.domain,r.interfaces,nothing)
58+
59+
function dsdt!(ds, s, local_params, t)
60+
jy = zeros(length(r.y0), length(r.y0))
61+
jp = zeros(length(r.y0), length(r.p))
62+
y = sol(t)
63+
jacy!(jy, y, r.p, t)
64+
jacp!(jp, y, r.p, t)
65+
@views @inbounds c = jp[:, i]
66+
@inbounds ds .= jy * s .+ c
67+
end
68+
69+
# Create list of ODEProblems for each batch of parameters
70+
71+
odefcn = ODEFunction(dsdt!)
72+
prob = ODEProblem(odefcn, zeros(length(r.y0)),r.tspan,0)
73+
s = solve(prob, senssolver; senskwargs...)
74+
solutions[i] = s
4675
end
47-
bigsol = generatesenssolution(sol, solutiondictionary, reactsens.ode)
76+
77+
bigsol = generatesenssolution(sol,solutions,reactsens.ode)
4878
return bigsol
4979
end
5080

@@ -83,65 +113,55 @@ function threadedsensitivities(react, paramindices; odesolver=nothing, senssolve
83113
forwardsensitivities=true, forwarddiff=react.forwarddiff, modelingtoolkit=react.modelingtoolkit,
84114
tau=react.tau)
85115

86-
salist = generatesensitivityodes(react, sol)
87116

88117
# Parallelize the SA calculations
89-
solutiondictionary = Dict()
90-
91-
@threads for i in paramindices
92-
s = solve(salist[i], senssolver; senskwargs...)
93-
solutiondictionary[i] = s
118+
solutions = Array{ODESolution}(undef,length(paramindices))
119+
nthreads = Threads.nthreads()
120+
if nthreads > 1 #each thread needs its own Reactor
121+
reacts = [deepcopy(react) for i in 1:nthreads]
122+
else
123+
reacts = [react]
94124
end
95-
96-
return solutiondictionary
97-
end
98-
99-
export threadedsensitivities
100-
101-
"""
102-
generate individual sensitivity ODEs for each parameter
103-
"""
104-
function generatesensitivityodes(react, sol)
105-
sa_list = []
106-
y0 = react.y0
107-
tspan = react.tspan
108-
p = react.p
109-
for i in 1:length(p)
110-
r = deepcopy(react)
111-
jacy!(J::Q2, y::T, p::V, t::Q) where {Q2,T,Q<:Real,V} = jacobiany!(J, y, p, t, r.domain, r.interfaces, nothing)
112-
jacp!(J::Q2, y::T, p::V, t::Q) where {Q2,T,Q<:Real,V} = jacobianp!(J, y, p, t, r.domain, r.interfaces, nothing)
113-
125+
@threads for n in 1:length(paramindices)
126+
i = paramindices[n]
127+
if nthreads > 1
128+
id = Threads.threadid()
129+
r = reacts[id]
130+
else
131+
r = react
132+
end
133+
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q<:Real,V} = jacobiany!(J,y,p,t,r.domain,r.interfaces,nothing)
134+
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q<:Real,V} = jacobianp!(J,y,p,t,r.domain,r.interfaces,nothing)
135+
114136
function dsdt!(ds, s, local_params, t)
115-
jy = zeros(length(y0), length(y0))
116-
jp = zeros(length(y0), length(p))
137+
jy = zeros(length(r.y0), length(r.y0))
138+
jp = zeros(length(r.y0), length(r.p))
117139
y = sol(t)
118-
jacy!(jy, y, p, t)
119-
jacp!(jp, y, p, t)
140+
jacy!(jy, y, r.p, t)
141+
jacp!(jp, y, r.p, t)
120142
@views @inbounds c = jp[:, i]
121143
@inbounds ds .= jy * s .+ c
122144
end
123145

124146
# Create list of ODEProblems for each batch of parameters
125147

126148
odefcn = ODEFunction(dsdt!)
127-
prob = ODEProblem(odefcn, zeros(length(y0)), tspan, 0)
128-
push!(sa_list, prob)
149+
prob = ODEProblem(odefcn, zeros(length(r.y0)),r.tspan,0)
150+
s = solve(prob, senssolver; senskwargs...)
151+
solutions[n] = s
129152
end
130-
return sa_list
153+
solutiondictionary = [i=>solutions[n] for (n,i) in enumerate(paramindices)]
154+
return solutiondictionary
131155
end
132156

157+
export threadedsensitivities
158+
133159
"""
134160
Combine ODE solutions into a sensitivity solution
135161
"""
136-
function generatesenssolution(sol, sensdict, sensprob)
162+
function generatesenssolution(sol, senssolns, sensprob)
137163
ts = sol.t
138-
ordkeys = sort([x for x in keys(sensdict)])
139-
u = deepcopy(sol.u)
140-
for k in ordkeys
141-
for i in 1:length(u)
142-
u[i] = vcat(u[i], sensdict[k](ts[i]))
143-
end
144-
end
164+
u = [vcat(sol.u[i],(senssolns[k](ts[i]) for k in 1:length(senssolns))...) for i in 1:length(sol.u)]
145165
bigsol = build_solution(sensprob, sol.alg, ts, u;
146166
interp=LinearInterpolation(ts, u), retcode=sol.retcode)
147167
return bigsol

0 commit comments

Comments
 (0)