@@ -35,16 +35,46 @@ function threadedsensitivities(react; odesolver=nothing, senssolver=nothing,
35
35
forwardsensitivities= true , forwarddiff= react. forwarddiff, modelingtoolkit= react. modelingtoolkit,
36
36
tau= react. tau)
37
37
38
- salist = generatesensitivityodes (react, sol)
39
38
40
39
# Parallelize the SA calculations
41
- solutiondictionary = Dict ()
40
+ solutions = Array {ODESolution} (undef,length (react. p))
41
+
42
+ nthreads = Threads. nthreads ()
43
+ if nthreads > 1 # each thread needs its own Reactor
44
+ reacts = [deepcopy (react) for i in 1 : nthreads]
45
+ else
46
+ reacts = [react]
47
+ end
42
48
43
49
@threads for i in 1 : length (react. p)
44
- s = solve (salist[i], senssolver; senskwargs... )
45
- solutiondictionary[i] = s
50
+ if nthreads > 1
51
+ id = Threads. threadid ()
52
+ r = reacts[id]
53
+ else
54
+ r = react
55
+ end
56
+ jacy! (J:: Q2 ,y:: T ,p:: V ,t:: Q ) where {Q2,T,Q<: Real ,V} = jacobiany! (J,y,p,t,r. domain,r. interfaces,nothing )
57
+ jacp! (J:: Q2 ,y:: T ,p:: V ,t:: Q ) where {Q2,T,Q<: Real ,V} = jacobianp! (J,y,p,t,r. domain,r. interfaces,nothing )
58
+
59
+ function dsdt! (ds, s, local_params, t)
60
+ jy = zeros (length (r. y0), length (r. y0))
61
+ jp = zeros (length (r. y0), length (r. p))
62
+ y = sol (t)
63
+ jacy! (jy, y, r. p, t)
64
+ jacp! (jp, y, r. p, t)
65
+ @views @inbounds c = jp[:, i]
66
+ @inbounds ds .= jy * s .+ c
67
+ end
68
+
69
+ # Create list of ODEProblems for each batch of parameters
70
+
71
+ odefcn = ODEFunction (dsdt!)
72
+ prob = ODEProblem (odefcn, zeros (length (r. y0)),r. tspan,0 )
73
+ s = solve (prob, senssolver; senskwargs... )
74
+ solutions[i] = s
46
75
end
47
- bigsol = generatesenssolution (sol, solutiondictionary, reactsens. ode)
76
+
77
+ bigsol = generatesenssolution (sol,solutions,reactsens. ode)
48
78
return bigsol
49
79
end
50
80
@@ -83,65 +113,55 @@ function threadedsensitivities(react, paramindices; odesolver=nothing, senssolve
83
113
forwardsensitivities= true , forwarddiff= react. forwarddiff, modelingtoolkit= react. modelingtoolkit,
84
114
tau= react. tau)
85
115
86
- salist = generatesensitivityodes (react, sol)
87
116
88
117
# Parallelize the SA calculations
89
- solutiondictionary = Dict ()
90
-
91
- @threads for i in paramindices
92
- s = solve (salist[i], senssolver; senskwargs... )
93
- solutiondictionary[i] = s
118
+ solutions = Array {ODESolution} (undef,length (paramindices))
119
+ nthreads = Threads. nthreads ()
120
+ if nthreads > 1 # each thread needs its own Reactor
121
+ reacts = [deepcopy (react) for i in 1 : nthreads]
122
+ else
123
+ reacts = [react]
94
124
end
95
-
96
- return solutiondictionary
97
- end
98
-
99
- export threadedsensitivities
100
-
101
- """
102
- generate individual sensitivity ODEs for each parameter
103
- """
104
- function generatesensitivityodes (react, sol)
105
- sa_list = []
106
- y0 = react. y0
107
- tspan = react. tspan
108
- p = react. p
109
- for i in 1 : length (p)
110
- r = deepcopy (react)
111
- jacy! (J:: Q2 , y:: T , p:: V , t:: Q ) where {Q2,T,Q<: Real ,V} = jacobiany! (J, y, p, t, r. domain, r. interfaces, nothing )
112
- jacp! (J:: Q2 , y:: T , p:: V , t:: Q ) where {Q2,T,Q<: Real ,V} = jacobianp! (J, y, p, t, r. domain, r. interfaces, nothing )
113
-
125
+ @threads for n in 1 : length (paramindices)
126
+ i = paramindices[n]
127
+ if nthreads > 1
128
+ id = Threads. threadid ()
129
+ r = reacts[id]
130
+ else
131
+ r = react
132
+ end
133
+ jacy! (J:: Q2 ,y:: T ,p:: V ,t:: Q ) where {Q2,T,Q<: Real ,V} = jacobiany! (J,y,p,t,r. domain,r. interfaces,nothing )
134
+ jacp! (J:: Q2 ,y:: T ,p:: V ,t:: Q ) where {Q2,T,Q<: Real ,V} = jacobianp! (J,y,p,t,r. domain,r. interfaces,nothing )
135
+
114
136
function dsdt! (ds, s, local_params, t)
115
- jy = zeros (length (y0), length (y0))
116
- jp = zeros (length (y0), length (p))
137
+ jy = zeros (length (r . y0), length (r . y0))
138
+ jp = zeros (length (r . y0), length (r . p))
117
139
y = sol (t)
118
- jacy! (jy, y, p, t)
119
- jacp! (jp, y, p, t)
140
+ jacy! (jy, y, r . p, t)
141
+ jacp! (jp, y, r . p, t)
120
142
@views @inbounds c = jp[:, i]
121
143
@inbounds ds .= jy * s .+ c
122
144
end
123
145
124
146
# Create list of ODEProblems for each batch of parameters
125
147
126
148
odefcn = ODEFunction (dsdt!)
127
- prob = ODEProblem (odefcn, zeros (length (y0)), tspan, 0 )
128
- push! (sa_list, prob)
149
+ prob = ODEProblem (odefcn, zeros (length (r. y0)),r. tspan,0 )
150
+ s = solve (prob, senssolver; senskwargs... )
151
+ solutions[n] = s
129
152
end
130
- return sa_list
153
+ solutiondictionary = [i=> solutions[n] for (n,i) in enumerate (paramindices)]
154
+ return solutiondictionary
131
155
end
132
156
157
+ export threadedsensitivities
158
+
133
159
"""
134
160
Combine ODE solutions into a sensitivity solution
135
161
"""
136
- function generatesenssolution (sol, sensdict , sensprob)
162
+ function generatesenssolution (sol, senssolns , sensprob)
137
163
ts = sol. t
138
- ordkeys = sort ([x for x in keys (sensdict)])
139
- u = deepcopy (sol. u)
140
- for k in ordkeys
141
- for i in 1 : length (u)
142
- u[i] = vcat (u[i], sensdict[k](ts[i]))
143
- end
144
- end
164
+ u = [vcat (sol. u[i],(senssolns[k](ts[i]) for k in 1 : length (senssolns)). .. ) for i in 1 : length (sol. u)]
145
165
bigsol = build_solution (sensprob, sol. alg, ts, u;
146
166
interp= LinearInterpolation (ts, u), retcode= sol. retcode)
147
167
return bigsol
0 commit comments