Skip to content

Commit aa61ed7

Browse files
committed
Parallel is off by default
1 parent e7b1a5d commit aa61ed7

File tree

3 files changed

+87
-37
lines changed

3 files changed

+87
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EpithelialDynamics1D"
22
uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b"
33
authors = ["Daniel VandenHeuvel <danj.vandenheuvel@gmail.com>"]
4-
version = "1.8.0"
4+
version = "1.8.1"
55

66
[deps]
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"

src/statistics.jl

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function node_densities(cell_positions::AbstractVector{T}; smooth_boundary=true)
7979
end
8080

8181
"""
82-
get_knots(sol, num_knots = 500; indices = eachindex(sol), stat=maximum)
82+
get_knots(sol, num_knots = 500; indices = eachindex(sol), stat=maximum, parallel=false)
8383
8484
Computes knots for each time, covering the extremum of the cell positions across all
8585
cell simulations. You can restrict the simultaions to consider using the `indices`.
@@ -88,7 +88,7 @@ to the vector of extrema at each time. For example, if `stat=maximum` then, at e
8888
the knots range between the smallest position observed and the maximum position
8989
observed across each simulation at that time.
9090
"""
91-
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), stat=(minimum, maximum))
91+
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), stat=(minimum, maximum), parallel=false)
9292
if stat isa Function
9393
stat = (stat, stat)
9494
end
@@ -98,15 +98,27 @@ function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol),
9898
knots = Vector{LinRange{Float64,Int}}(undef, length(first(sol)))
9999
end
100100
times = first(sol).t
101-
Base.Threads.@threads for i in eachindex(times)
102-
local a, b
103-
a = zeros(length(indices))
104-
b = zeros(length(indices))
105-
for (ℓ, j) in enumerate(indices)
106-
a[ℓ] = sol[j].u[i][begin]
107-
b[ℓ] = sol[j].u[i][end]
101+
if parallel
102+
Base.Threads.@threads for i in eachindex(times)
103+
local a, b
104+
a = zeros(length(indices))
105+
b = zeros(length(indices))
106+
for (ℓ, j) in enumerate(indices)
107+
a[ℓ] = sol[j].u[i][begin]
108+
b[ℓ] = sol[j].u[i][end]
109+
end
110+
knots[i] = LinRange(stat[1](a), stat[2](b), num_knots)
111+
end
112+
else
113+
for i in eachindex(times)
114+
a = zeros(length(indices))
115+
b = zeros(length(indices))
116+
for (ℓ, j) in enumerate(indices)
117+
a[ℓ] = sol[j].u[i][begin]
118+
b[ℓ] = sol[j].u[i][end]
119+
end
120+
knots[i] = LinRange(stat[1](a), stat[2](b), num_knots)
108121
end
109-
knots[i] = LinRange(stat[1](a), stat[2](b), num_knots)
110122
end
111123
return knots
112124
end
@@ -137,7 +149,8 @@ negative values are set to zero.
137149
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
138150
- `num_knots::Int = 500`: The number of knots to use for the spline interpolation.
139151
- `stat = (minimum, maximum)`: How to summarise the knots for `get_knots`.
140-
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, stat)`: The knots to use for the spline interpolation.
152+
- `parallel = false`: Whether to use multithreading for the loops.
153+
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, stat, parallel)`: The knots to use for the spline interpolation.
141154
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
142155
- `interp_fnc = (u, t) -> LinearInterpolation{true}(u, t)`: The function to use for constructing the interpolant.
143156
- `smooth_boundary::Bool = true`: Whether to use the smooth boundary node densities.
@@ -155,44 +168,81 @@ function node_densities(sol::EnsembleSolution;
155168
indices=eachindex(sol),
156169
num_knots=500,
157170
stat=(minimum, maximum),
171+
parallel=false,
158172
knots=get_knots(sol, num_knots; indices, stat),
159173
alpha=0.05,
160174
interp_fnc=(u, t) -> LinearInterpolation{true}(u, t),
161175
smooth_boundary=true,
162176
extrapolate=false)
163177
q = Vector{Vector{Vector{Float64}}}(undef, length(indices))
164178
r = Vector{Vector{Vector{Float64}}}(undef, length(indices))
165-
Base.Threads.@threads for i in eachindex(indices)
166-
q[i] = node_densities.(sol[indices[i]].u; smooth_boundary)
167-
r[i] = sol[indices[i]].u
179+
if parallel
180+
Base.Threads.@threads for i in eachindex(indices)
181+
q[i] = node_densities.(sol[indices[i]].u; smooth_boundary)
182+
r[i] = sol[indices[i]].u
183+
end
184+
else
185+
for i in eachindex(indices)
186+
q[i] = node_densities.(sol[indices[i]].u; smooth_boundary)
187+
r[i] = sol[indices[i]].u
188+
end
168189
end
169190
nt = length(first(sol))
170191
nsims = length(indices)
171192
q_splines = zeros(num_knots, nt, nsims)
172193
q_means = [zeros(num_knots) for _ in 1:nt]
173194
q_lowers = [zeros(num_knots) for _ in 1:nt]
174195
q_uppers = [zeros(num_knots) for _ in 1:nt]
175-
Base.Threads.@threads for k in 1:nsims
176-
for j in 1:nt
177-
densities = q[k][j]
178-
cell_positions = r[k][j]
179-
interp = interp_fnc(densities, cell_positions)
180-
for i in eachindex(knots[j])
181-
if !extrapolate && knots[j][i] > r[k][j][end]
182-
q_splines[i, j, k] = 0.0
183-
else
184-
q_splines[i, j, k] = max(0.0, interp(knots[j][i]))
196+
if parallel
197+
Base.Threads.@threads for k in 1:nsims
198+
for j in 1:nt
199+
densities = q[k][j]
200+
cell_positions = r[k][j]
201+
interp = interp_fnc(densities, cell_positions)
202+
for i in eachindex(knots[j])
203+
if !extrapolate && knots[j][i] > r[k][j][end]
204+
q_splines[i, j, k] = 0.0
205+
else
206+
q_splines[i, j, k] = max(0.0, interp(knots[j][i]))
207+
end
208+
end
209+
end
210+
end
211+
else
212+
for k in 1:nsims
213+
for j in 1:nt
214+
densities = q[k][j]
215+
cell_positions = r[k][j]
216+
interp = interp_fnc(densities, cell_positions)
217+
for i in eachindex(knots[j])
218+
if !extrapolate && knots[j][i] > r[k][j][end]
219+
q_splines[i, j, k] = 0.0
220+
else
221+
q_splines[i, j, k] = max(0.0, interp(knots[j][i]))
222+
end
185223
end
186224
end
187225
end
188226
end
189-
Base.Threads.@threads for j in 1:nt
190-
knot_range = knots[j]
191-
for i in eachindex(knot_range)
192-
q_values = @views q_splines[i, j, :]
193-
q_means[j][i] = mean(q_values)
194-
q_lowers[j][i] = quantile(q_values, alpha / 2)
195-
q_uppers[j][i] = quantile(q_values, 1 - alpha / 2)
227+
if parallel
228+
Base.Threads.@threads for j in 1:nt
229+
knot_range = knots[j]
230+
for i in eachindex(knot_range)
231+
q_values = @views q_splines[i, j, :]
232+
q_means[j][i] = mean(q_values)
233+
q_lowers[j][i] = quantile(q_values, alpha / 2)
234+
q_uppers[j][i] = quantile(q_values, 1 - alpha / 2)
235+
end
236+
end
237+
else
238+
for j in 1:nt
239+
knot_range = knots[j]
240+
for i in eachindex(knot_range)
241+
q_values = @views q_splines[i, j, :]
242+
q_means[j][i] = mean(q_values)
243+
q_lowers[j][i] = quantile(q_values, alpha / 2)
244+
q_uppers[j][i] = quantile(q_values, 1 - alpha / 2)
245+
end
196246
end
197247
end
198248
return (q=q, r=r, means=q_means, lowers=q_lowers, uppers=q_uppers, knots=knots)

test/step_function.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ end
316316
# Test the statistics with a specific interpolation function
317317
_indices = eachindex(sol)
318318
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline)
319-
@inferred node_densities(sol; indices=_indices, interp_fnc=CubicSpline)
319+
@inferred node_densities(sol; indices=_indices, interp_fnc=CubicSpline, parallel=true)
320320
@test all((LinRange(0, 30, 500)), knots)
321321
for (enum_k, k) in enumerate(_indices)
322322
for j in rand(1:length(sol[k]), 40)
@@ -543,7 +543,7 @@ end
543543

544544
# Test the statistics with a specific interpolation function
545545
_indices = rand(eachindex(sol), 20)
546-
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline, stat = minimum)
546+
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline, stat=minimum)
547547
@inferred node_densities(sol; indices=_indices, interp_fnc=CubicSpline)
548548
for j in eachindex(knots)
549549
a = Inf
@@ -583,7 +583,7 @@ end
583583
_L = _L[:, _indices]
584584
_mL = mean.(eachrow(_L))
585585
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean)
586-
@inferred node_densities(sol; indices=_indices, stat=mean)
586+
@inferred node_densities(sol; indices=_indices, stat=mean, parallel=true)
587587
for j in eachindex(knots)
588588
a = mean(sol[k][j][begin] for k in _indices)
589589
b = mean(sol[k][j][end] for k in _indices)
@@ -621,7 +621,7 @@ end
621621
_L = _L[:, _indices]
622622
_mL = mean.(eachrow(_L))
623623
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean, extrapolate=true)
624-
@inferred node_densities(sol; indices=_indices, stat=mean,extrapolate=true)
624+
@inferred node_densities(sol; indices=_indices, stat=mean, extrapolate=true, parallel=true)
625625
for j in eachindex(knots)
626626
a = mean(sol[k][j][begin] for k in _indices)
627627
b = mean(sol[k][j][end] for k in _indices)
@@ -645,7 +645,7 @@ end
645645
end
646646
for j in rand(eachindex(mb_sol), 40)
647647
for i in eachindex(knots[j])
648-
all_q = max.(0.0, [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) for k in eachindex(_indices)])
648+
all_q = max.(0.0, [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) for k in eachindex(_indices)])
649649
@test mean(all_q) means[j][i] rtol = 1e-3
650650
@test quantile(all_q, 0.025) lowers[j][i] rtol = 1e-3
651651
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3

0 commit comments

Comments
 (0)