Skip to content

Commit 508d3e0

Browse files
Merge pull request #161 from SciML/zero_init_safety
Tests for zero init safety
2 parents 6dd7023 + 31aed6b commit 508d3e0

File tree

4 files changed

+93
-88
lines changed

4 files changed

+93
-88
lines changed

.github/workflows/Downstream.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ jobs:
1717
os: [ubuntu-latest]
1818
package:
1919
- {user: SciML, repo: OrdinaryDiffEq.jl, group: InterfaceII}
20+
- {user: SciML, repo: ModelingToolkit.jl, group: All}
21+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core1}
2022

2123
steps:
2224
- uses: actions/checkout@v2

src/factorization.jl

Lines changed: 69 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -70,94 +70,6 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
7070
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
7171
end
7272

73-
# This could be a GenericFactorization perhaps?
74-
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
75-
reuse_symbolic::Bool = true
76-
end
77-
78-
function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
79-
verbose)
80-
A = convert(AbstractMatrix, A)
81-
zerobased = SparseArrays.getcolptr(A)[1] == 0
82-
res = SuiteSparse.UMFPACK.UmfpackLU(C_NULL, C_NULL, size(A, 1), size(A, 2),
83-
zerobased ? copy(SparseArrays.getcolptr(A)) :
84-
SuiteSparse.decrement(SparseArrays.getcolptr(A)),
85-
zerobased ? copy(rowvals(A)) :
86-
SuiteSparse.decrement(rowvals(A)),
87-
copy(nonzeros(A)), 0)
88-
finalizer(SuiteSparse.UMFPACK.umfpack_free_symbolic, res)
89-
res
90-
end
91-
92-
function do_factorization(::UMFPACKFactorization, A, b, u)
93-
A = convert(AbstractMatrix, A)
94-
if A isa SparseMatrixCSC
95-
return lu(A)
96-
else
97-
error("Sparse LU is not defined for $(typeof(A))")
98-
end
99-
end
100-
101-
function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
102-
A = cache.A
103-
A = convert(AbstractMatrix, A)
104-
if cache.isfresh
105-
if cache.cacheval !== nothing && alg.reuse_symbolic
106-
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
107-
# This won't recompute if it does.
108-
SuiteSparse.UMFPACK.umfpack_symbolic!(cache.cacheval)
109-
fact = lu!(cache.cacheval, A)
110-
else
111-
fact = do_factorization(alg, A, cache.b, cache.u)
112-
end
113-
cache = set_cacheval(cache, fact)
114-
end
115-
116-
y = ldiv!(cache.u, cache.cacheval, cache.b)
117-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
118-
end
119-
120-
Base.@kwdef struct KLUFactorization <: AbstractFactorization
121-
reuse_symbolic::Bool = true
122-
end
123-
124-
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
125-
verbose)
126-
return KLU.KLUFactorization(convert(AbstractMatrix, A)) # this takes care of the copy internally.
127-
end
128-
129-
function do_factorization(::KLUFactorization, A, b, u)
130-
A = convert(AbstractMatrix, A)
131-
if A isa SparseMatrixCSC
132-
return klu(A)
133-
else
134-
error("KLU is not defined for $(typeof(A))")
135-
end
136-
end
137-
138-
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
139-
A = cache.A
140-
A = convert(AbstractMatrix, A)
141-
if cache.isfresh
142-
if cache.cacheval !== nothing && alg.reuse_symbolic
143-
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
144-
# This won't recompute if it does.
145-
KLU.klu_analyze!(cache.cacheval)
146-
copyto!(cache.cacheval.nzval, A.nzval)
147-
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
148-
KLU.klu_factor!(cache.cacheval)
149-
end
150-
fact = KLU.klu!(cache.cacheval, A)
151-
else
152-
fact = do_factorization(alg, A, cache.b, cache.u)
153-
end
154-
cache = set_cacheval(cache, fact)
155-
end
156-
157-
y = ldiv!(cache.u, cache.cacheval, cache.b)
158-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
159-
end
160-
16173
## QRFactorization
16274

16375
struct QRFactorization{P} <: AbstractFactorization
@@ -327,6 +239,75 @@ function init_cacheval(alg::Union{GenericFactorization,
327239
do_factorization(alg, newA, b, u)
328240
end
329241

242+
################################## Factorizations which require solve overloads
243+
244+
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
245+
reuse_symbolic::Bool = true
246+
end
247+
248+
function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
249+
verbose)
250+
A = convert(AbstractMatrix, A)
251+
zerobased = SparseArrays.getcolptr(A)[1] == 0
252+
res = SuiteSparse.UMFPACK.UmfpackLU(C_NULL, C_NULL, size(A, 1), size(A, 2),
253+
zerobased ? copy(SparseArrays.getcolptr(A)) :
254+
SuiteSparse.decrement(SparseArrays.getcolptr(A)),
255+
zerobased ? copy(rowvals(A)) :
256+
SuiteSparse.decrement(rowvals(A)),
257+
copy(nonzeros(A)), 0)
258+
finalizer(SuiteSparse.UMFPACK.umfpack_free_symbolic, res)
259+
res
260+
end
261+
262+
function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
263+
A = cache.A
264+
A = convert(AbstractMatrix, A)
265+
if cache.isfresh
266+
if cache.cacheval !== nothing && alg.reuse_symbolic
267+
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
268+
fact = lu!(cache.cacheval, A)
269+
else
270+
fact = do_factorization(alg, A, cache.b, cache.u)
271+
end
272+
cache = set_cacheval(cache, fact)
273+
end
274+
275+
y = ldiv!(cache.u, cache.cacheval, cache.b)
276+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
277+
end
278+
279+
Base.@kwdef struct KLUFactorization <: AbstractFactorization
280+
reuse_symbolic::Bool = true
281+
end
282+
283+
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
284+
verbose)
285+
return KLU.KLUFactorization(convert(AbstractMatrix, A)) # this takes care of the copy internally.
286+
end
287+
288+
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
289+
A = cache.A
290+
A = convert(AbstractMatrix, A)
291+
if cache.isfresh
292+
if cache.cacheval !== nothing && alg.reuse_symbolic
293+
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
294+
# This won't recompute if it does.
295+
KLU.klu_analyze!(cache.cacheval)
296+
copyto!(cache.cacheval.nzval, A.nzval)
297+
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
298+
KLU.klu_factor!(cache.cacheval)
299+
end
300+
fact = KLU.klu!(cache.cacheval, A)
301+
else
302+
fact = do_factorization(alg, A, cache.b, cache.u)
303+
end
304+
cache = set_cacheval(cache, fact)
305+
end
306+
307+
y = ldiv!(cache.u, cache.cacheval, cache.b)
308+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
309+
end
310+
330311
## RFLUFactorization
331312

332313
struct RFLUFactorization{P, T} <: AbstractFactorization

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ end
2020

2121
if GROUP == "All" || GROUP == "Core"
2222
@time @safetestset "Basic Tests" begin include("basictests.jl") end
23+
@time @safetestset "Zero Initialization Tests" begin include("zeroinittests.jl") end
2324
end
2425

2526
if GROUP == "LinearSolveCUDA"

test/zeroinittests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using LinearSolve, LinearAlgebra, SparseArrays, Test
2+
3+
A = Diagonal(ones(4))
4+
b = rand(4)
5+
A = sparse(A)
6+
Anz = deepcopy(A)
7+
A.nzval .= 0
8+
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
9+
10+
function test_nonzero_init(alg = nothing)
11+
linprob = LinearProblem(A, b)
12+
13+
cache = init(linprob, alg)
14+
cache = LinearSolve.set_A(cache, Anz)
15+
sol = solve(cache; cache_kwargs...)
16+
@test sol.u == b
17+
end
18+
19+
test_nonzero_init()
20+
test_nonzero_init(KLUFactorization())
21+
test_nonzero_init(UMFPACKFactorization())

0 commit comments

Comments
 (0)