Skip to content

Commit 3913861

Browse files
committed
Make getting the knots MUCH faster, and add the ability to use averaged leading edge for knots instead of a maximum
1 parent d5f3ccd commit 3913861

File tree

3 files changed

+101
-14
lines changed

3 files changed

+101
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EpithelialDynamics1D"
22
uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b"
33
authors = ["Daniel VandenHeuvel <danj.vandenheuvel@gmail.com>"]
4-
version = "1.3.3"
4+
version = "1.4.0"
55

66
[deps]
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"

src/statistics.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ function node_densities(cell_positions::AbstractVector{T}) where {T<:Number}
6060
end
6161

6262
"""
63-
get_knots(sol, num_knots = 500; indices = eachindex(sol))
63+
get_knots(sol, num_knots = 500; indices = eachindex(sol), use_max=true)
6464
6565
Computes knots for each time, covering the extremum of the cell positions across all
6666
cell simulations. You can restrict the simultaions to consider using the `indices`.
67+
If `use_max` is true, then the knots will be obtained by taking the extreme node positions
68+
for each `time`, otherwise the average is used.
6769
"""
68-
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol))
70+
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), use_extrema=true)
6971
@static if VERSION < v"1.7"
7072
knots = Vector{LinRange{Float64}}(undef, length(first(sol)))
7173
else
@@ -74,14 +76,30 @@ function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol))
7476
times = first(sol).t
7577
Base.Threads.@threads for i in eachindex(times)
7678
local a, b
77-
a = Inf
78-
b = -Inf
79+
if use_extrema
80+
a = Inf
81+
b = -Inf
82+
else
83+
a = 0.0
84+
b = 0.0
85+
ctr = 0
86+
end
7987
for j in indices
80-
for r in sol[j][i]
81-
a = min(a, r[begin])
82-
b = max(b, r[end])
88+
_a = sol[j][i][begin]
89+
_b = sol[j][i][end]
90+
if use_extrema
91+
a = min(a, _a)
92+
b = max(b, _b)
93+
else
94+
a += _a
95+
b += _b
96+
ctr += 1
8397
end
8498
end
99+
if !use_extrema
100+
a /= ctr
101+
b /= ctr
102+
end
85103
knots[i] = LinRange(a, b, num_knots)
86104
end
87105
return knots
@@ -104,7 +122,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
104122
# Keyword Arguments
105123
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
106124
- `num_knots::Int = 500`: The number of knots to use for the spline interpolation.
107-
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices)`: The knots to use for the spline interpolation.
125+
- `use_extrema::Bool = true`: Whether to use the extrema of the cell positions for the knots, or the average.
126+
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, use_extrema)`: The knots to use for the spline interpolation.
108127
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
109128
- `interp_fnc = (u, t) -> LinearInterpolation{true}(u, t)`: The function to use for constructing the interpolant.
110129
@@ -116,12 +135,13 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
116135
- `uppers::Vector{Vector{Float64}}`: The upper bounds of the confidence intervals for the node densities for each cell simulation.
117136
- `knots::Vector{Vector{Float64}}`: The knots used for the spline interpolation.
118137
"""
119-
function node_densities(sol::EnsembleSolution;
120-
indices=eachindex(sol),
121-
num_knots=500,
122-
knots=get_knots(sol, num_knots; indices),
138+
function node_densities(sol::EnsembleSolution;
139+
indices=eachindex(sol),
140+
num_knots=500,
141+
use_extrema=true,
142+
knots=get_knots(sol, num_knots; indices, use_extrema),
123143
alpha=0.05,
124-
interp_fnc = (u, t) -> LinearInterpolation{true}(u, t))
144+
interp_fnc=(u, t) -> LinearInterpolation{true}(u, t))
125145
q = Vector{Vector{Vector{Float64}}}(undef, length(indices))
126146
r = Vector{Vector{Vector{Float64}}}(undef, length(indices))
127147
Base.Threads.@threads for i in eachindex(indices)

test/step_function.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,35 @@ end
316316
@test quantile(all_q, 0.975) uppers[j][i]
317317
end
318318
end
319+
320+
# Using average leading edge
321+
_indices = rand(eachindex(sol), 40)
322+
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false)
323+
@inferred node_densities(sol; indices=_indices, use_extrema=false)
324+
@test all((LinRange(0, 30, 500)), knots)
325+
for (enum_k, k) in enumerate(_indices)
326+
for j in rand(1:length(sol[k]), 40)
327+
for i in rand(1:length(sol[k][j]), 60)
328+
if i == 1
329+
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
330+
elseif i == length(sol[k][j])
331+
n = length(sol[k][j])
332+
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
333+
else
334+
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
335+
end
336+
@test r[enum_k][j][i] == sol[k][j][i]
337+
end
338+
end
339+
end
340+
for j in rand(1:length(fvm_sol), 50)
341+
for i in rand(1:length(knots[j]), 50)
342+
all_q = [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) for k in eachindex(_indices)]
343+
@test mean(all_q) means[j][i]
344+
@test quantile(all_q, 0.025) lowers[j][i]
345+
@test quantile(all_q, 0.975) uppers[j][i]
346+
end
347+
end
319348
end
320349

321350
@testset "Proliferation with a Moving Boundary" begin
@@ -519,4 +548,42 @@ end
519548
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3
520549
end
521550
end
551+
552+
# Using the average leading edge
553+
(; L) = leading_edges(sol)
554+
_L = stack(L)
555+
_indices = rand(eachindex(sol), 20)
556+
_L = _L[:, _indices]
557+
_mL = mean.(eachrow(_L))
558+
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false)
559+
@inferred node_densities(sol; indices=_indices, use_extrema=false)
560+
for j in eachindex(knots)
561+
a = mean(sol[k][j][begin] for k in _indices)
562+
b = mean(sol[k][j][end] for k in _indices)
563+
@test knots[j] LinRange(a, b, 500)
564+
@test knots[j][end] _mL[j]
565+
end
566+
for (enum_k, k) in enumerate(_indices)
567+
for j in rand(1:length(sol[k]), 40)
568+
for i in 1:length(sol[k][j])
569+
if i == 1
570+
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
571+
elseif i == length(sol[k][j])
572+
n = length(sol[k][j])
573+
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
574+
else
575+
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
576+
end
577+
@test r[enum_k][j][i] == sol[k][j][i]
578+
end
579+
end
580+
end
581+
for j in rand(eachindex(mb_sol), 40)
582+
for i in eachindex(knots[j])
583+
all_q = max.(0.0, [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) * (knots[j][i] r[k][j][end]) for k in eachindex(_indices)])
584+
@test mean(all_q) means[j][i] rtol = 1e-3
585+
@test quantile(all_q, 0.025) lowers[j][i] rtol = 1e-3
586+
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3
587+
end
588+
end
522589
end

0 commit comments

Comments
 (0)