Skip to content

Commit 4afec5a

Browse files
authored
Merge pull request #495 from SciML/ap/fix_le
Fix the formatting
2 parents 89ea6ee + 69cabe9 commit 4afec5a

File tree

3 files changed

+1444
-1441
lines changed

3 files changed

+1444
-1441
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,64 @@
1-
module LinearSolveCUDAExt
2-
3-
using CUDA
4-
using LinearSolve
5-
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
6-
using SciMLBase: AbstractSciMLOperator
7-
8-
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
9-
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
10-
if LinearSolve.cudss_loaded(A)
11-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
12-
else
13-
if !LinearSolve.ALREADY_WARNED_CUDSS[]
14-
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
15-
LinearSolve.ALREADY_WARNED_CUDSS[] = true
16-
end
17-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
18-
end
19-
end
20-
21-
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
22-
if !LinearSolve.CUDSS_LOADED[]
23-
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
24-
end
25-
nothing
26-
end
27-
28-
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
29-
kwargs...)
30-
if cache.isfresh
31-
fact = qr(CUDA.CuArray(cache.A))
32-
cache.cacheval = fact
33-
cache.isfresh = false
34-
end
35-
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
36-
cache.u .= y
37-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
38-
end
39-
40-
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
41-
maxiters::Int, abstol, reltol, verbose::Bool,
42-
assumptions::OperatorAssumptions)
43-
qr(CUDA.CuArray(A))
44-
end
45-
46-
function LinearSolve.init_cacheval(::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
47-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
48-
nothing
49-
end
50-
51-
function LinearSolve.init_cacheval(::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
52-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
53-
nothing
54-
end
55-
56-
function LinearSolve.init_cacheval(::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
57-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
58-
nothing
59-
end
60-
61-
end
1+
module LinearSolveCUDAExt
2+
3+
using CUDA
4+
using LinearSolve
5+
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
6+
using SciMLBase: AbstractSciMLOperator
7+
8+
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
9+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
10+
if LinearSolve.cudss_loaded(A)
11+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
12+
else
13+
if !LinearSolve.ALREADY_WARNED_CUDSS[]
14+
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
15+
LinearSolve.ALREADY_WARNED_CUDSS[] = true
16+
end
17+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
18+
end
19+
end
20+
21+
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
22+
if !LinearSolve.CUDSS_LOADED[]
23+
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
24+
end
25+
nothing
26+
end
27+
28+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
29+
kwargs...)
30+
if cache.isfresh
31+
fact = qr(CUDA.CuArray(cache.A))
32+
cache.cacheval = fact
33+
cache.isfresh = false
34+
end
35+
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
36+
cache.u .= y
37+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
38+
end
39+
40+
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
41+
maxiters::Int, abstol, reltol, verbose::Bool,
42+
assumptions::OperatorAssumptions)
43+
qr(CUDA.CuArray(A))
44+
end
45+
46+
function LinearSolve.init_cacheval(
47+
::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
48+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
49+
nothing
50+
end
51+
52+
function LinearSolve.init_cacheval(
53+
::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
54+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
55+
nothing
56+
end
57+
58+
function LinearSolve.init_cacheval(
59+
::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
60+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
61+
nothing
62+
end
63+
64+
end

0 commit comments

Comments
 (0)