Skip to content

Commit 5115d34

Browse files
authored
Merge pull request #293 from ErikQQY/qqy/safe_similar
Use safe similar
2 parents 71fa0f9 + b12dea5 commit 5115d34

File tree

12 files changed

+29
-28
lines changed

12 files changed

+29
-28
lines changed

lib/BoundaryValueDiffEqCore/src/utils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,14 @@ function __get_bcresid_prototype(::StandardSecondOrderBVProblem, prob::BVProblem
271271
return prototype, size(prototype)
272272
end
273273

274-
@inline function __similar(x, args...)
274+
@inline function safe_similar(x::AbstractArray{<:T}, args...) where {T <: Number}
275275
y = similar(x, args...)
276-
return zero(y)
276+
fill!(y, T(0))
277+
return y
277278
end
278279

279280
@inline function __fill_like(v, x, args...)
280-
y = __similar(x, args...)
281+
y = similar(x, args...)
281282
fill!(y, v)
282283
return y
283284
end

lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
1212
eval_bc_residual!, get_tmp, __maybe_matmul!, __resize!,
1313
__extract_problem_details, __initial_guess,
1414
__default_coloring_algorithm, __maybe_allocate_diffcache,
15-
__restructure_sol, __get_bcresid_prototype, __similar, __vec,
16-
__vec_f, __vec_f!, __vec_bc, __vec_bc!,
15+
__restructure_sol, __get_bcresid_prototype, safe_similar,
16+
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
1717
recursive_flatten_twopoint!, __internal_nlsolve_problem,
1818
MaybeDiffCache, __extract_mesh, __extract_u0,
1919
__has_initial_guess, __initial_guess_length,

lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ After we construct an interpolant, we use interp_eval to evaluate it.
1717
(; f, M, stage, p, ITU) = cache
1818
(; q_coeff) = ITU
1919

20-
K = __similar(cache.y[1].du, M, stage)
20+
K = safe_similar(cache.y[1].du, M, stage)
2121

2222
ctr_y = (j - 1) * (stage + 1) + 1
2323

lib/BoundaryValueDiffEqFIRK/src/collocation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123
@views function Φ(
124124
fᵢ_cache, k_discrete, f, TU::FIRKTableau{false}, y, u, p, mesh, mesh_dt, stage::Int)
125125
(; c, a, b) = TU
126-
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
126+
residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]]
127127
tmp1 = get_tmp(fᵢ_cache, u)
128128
K = get_tmp(k_discrete[1], u) # Not optimal # TODO
129129
T = eltype(u)
@@ -161,7 +161,7 @@ end
161161
(; b) = TU
162162
(; nest_prob, alg, nest_tol) = cache
163163

164-
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
164+
residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]]
165165

166166
T = eltype(u)
167167
nestprob_p = vcat(T(mesh[1]), T(mesh_dt[1]), get_tmp(y[1], u))

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-3
117117
stage = alg_stage(alg)
118118

119119
k_discrete = [__maybe_allocate_diffcache(
120-
__similar(X, M, stage), chunksize, alg.jac_alg) for _ in 1:Nig]
120+
safe_similar(X, M, stage), chunksize, alg.jac_alg) for _ in 1:Nig]
121121

122122
bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)
123123

@@ -131,7 +131,7 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-3
131131
nothing
132132
end
133133

134-
defect = VectorOfArray([__similar(X, ifelse(adaptive, M, 0)) for _ in 1:Nig])
134+
defect = VectorOfArray([safe_similar(X, ifelse(adaptive, M, 0)) for _ in 1:Nig])
135135

136136
# Transform the functions to handle non-vector inputs
137137
bcresid_prototype = __vec(bcresid_prototype)
@@ -212,7 +212,7 @@ function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e
212212
y = __alloc.(copy.(y₀.u)) # Runtime dispatch
213213

214214
k_discrete = [__maybe_allocate_diffcache(
215-
__similar(X, M, stage), chunksize, alg.jac_alg) for _ in 1:Nig] # Runtime dispatch
215+
safe_similar(X, M, stage), chunksize, alg.jac_alg) for _ in 1:Nig] # Runtime dispatch
216216

217217
bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)
218218

@@ -436,7 +436,7 @@ function __construct_nlproblem(
436436

437437
resid_bc = cache.bcresid_prototype
438438
L = length(resid_bc)
439-
resid_collocation = __similar(y, cache.M * (N - 1) * (stage + 1))
439+
resid_collocation = safe_similar(y, cache.M * (N - 1) * (stage + 1))
440440

441441
cache_bc = if iip
442442
DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p))
@@ -517,7 +517,7 @@ function __construct_nlproblem(
517517
(; stage) = cache
518518
N = length(cache.mesh)
519519

520-
resid_collocation = __similar(y, cache.M * (N - 1) * (stage + 1))
520+
resid_collocation = safe_similar(y, cache.M * (N - 1) * (stage + 1))
521521

522522
resid = vcat(
523523
@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), resid_collocation,
@@ -576,7 +576,7 @@ function __construct_nlproblem(
576576
N = length(cache.mesh)
577577
resid_bc = cache.bcresid_prototype
578578
L = length(resid_bc)
579-
resid_collocation = __similar(y, cache.M * (N - 1))
579+
resid_collocation = safe_similar(y, cache.M * (N - 1))
580580

581581
cache_bc = if iip
582582
DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p))
@@ -654,7 +654,7 @@ function __construct_nlproblem(
654654
N = length(cache.mesh)
655655

656656
resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
657-
__similar(y, cache.M * (N - 1)),
657+
safe_similar(y, cache.M * (N - 1)),
658658
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
659659

660660
diffmode = if jac_alg.diffmode isa AutoSparse

lib/BoundaryValueDiffEqFIRK/src/interpolation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheExpand}
139139
# Quick handle for the case where tval is at the boundary
140140
(tval == t[1]) && return first(u)
141141
(tval == t[end]) && return last(u)
142-
K = __similar(first(u), length(first(u)), stage)
142+
K = safe_similar(first(u), length(first(u)), stage)
143143
j = interval(t, tval)
144144
ctr_y = (j - 1) * (stage + 1) + 1
145145

lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
1212
eval_bc_residual!, get_tmp, __maybe_matmul!, __resize!,
1313
__extract_problem_details, __initial_guess,
1414
__maybe_allocate_diffcache, __restructure_sol,
15-
__get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!,
16-
__vec_bc, __vec_bc!, recursive_flatten_twopoint!,
15+
__get_bcresid_prototype, safe_similar, __vec, __vec_f,
16+
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
1717
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
1818
__has_initial_guess, __initial_guess_length,
1919
__initial_guess_on_mesh, __flatten_initial_guess,

lib/BoundaryValueDiffEqMIRK/src/collocation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
@views function Φ(
3838
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
3939
(; c, v, x, b) = TU
40-
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
40+
residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]]
4141
tmp = get_tmp(fᵢ_cache, u)
4242
T = eltype(u)
4343
for i in eachindex(k_discrete)

lib/BoundaryValueDiffEqMIRK/src/mirk.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol =
5555
stage = alg_stage(alg)
5656

5757
k_discrete = [__maybe_allocate_diffcache(
58-
__similar(X, N, stage), chunksize, alg.jac_alg) for _ in 1:Nig]
58+
safe_similar(X, N, stage), chunksize, alg.jac_alg) for _ in 1:Nig]
5959
k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig])
6060

6161
bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)
@@ -321,7 +321,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
321321

322322
resid_bc = cache.bcresid_prototype
323323
L = length(resid_bc)
324-
resid_collocation = __similar(y, cache.M * (N - 1))
324+
resid_collocation = safe_similar(y, cache.M * (N - 1))
325325

326326
cache_bc = if iip
327327
DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p))
@@ -440,7 +440,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
440440
N = length(cache.mesh)
441441

442442
resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
443-
__similar(y, cache.M * (N - 1)),
443+
safe_similar(y, cache.M * (N - 1)),
444444
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
445445
L = length(cache.bcresid_prototype)
446446

lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
1111
eval_bc_residual!, get_tmp, __maybe_matmul!,
1212
__extract_problem_details, __initial_guess,
1313
__maybe_allocate_diffcache, __restructure_sol,
14-
__get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!,
15-
__vec_bc, __vec_bc!, __vec_so_bc!, __vec_so_bc,
14+
__get_bcresid_prototype, safe_similar, __vec, __vec_f,
15+
__vec_f!, __vec_bc, __vec_bc!, __vec_so_bc!, __vec_so_bc,
1616
recursive_flatten_twopoint!, __internal_nlsolve_problem,
1717
__extract_mesh, __extract_u0, __has_initial_guess,
1818
__initial_guess_length, __initial_guess_on_mesh,

0 commit comments

Comments
 (0)