|
1 |
| -module LinearSolveCUDAExt |
2 |
| - |
3 |
| -using CUDA |
4 |
| -using LinearSolve |
5 |
| -using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
6 |
| -using SciMLBase: AbstractSciMLOperator |
7 |
| - |
8 |
| -function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b, |
9 |
| - assump::OperatorAssumptions{Bool}) where {Tv, Ti} |
10 |
| - if LinearSolve.cudss_loaded(A) |
11 |
| - LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) |
12 |
| - else |
13 |
| - if !LinearSolve.ALREADY_WARNED_CUDSS[] |
14 |
| - @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov") |
15 |
| - LinearSolve.ALREADY_WARNED_CUDSS[] = true |
16 |
| - end |
17 |
| - LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) |
18 |
| - end |
19 |
| -end |
20 |
| - |
21 |
| -function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) |
22 |
| - if !LinearSolve.CUDSS_LOADED[] |
23 |
| - error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.") |
24 |
| - end |
25 |
| - nothing |
26 |
| -end |
27 |
| - |
28 |
| -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; |
29 |
| - kwargs...) |
30 |
| - if cache.isfresh |
31 |
| - fact = qr(CUDA.CuArray(cache.A)) |
32 |
| - cache.cacheval = fact |
33 |
| - cache.isfresh = false |
34 |
| - end |
35 |
| - y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
36 |
| - cache.u .= y |
37 |
| - SciMLBase.build_linear_solution(alg, y, nothing, cache) |
38 |
| -end |
39 |
| - |
40 |
| -function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
41 |
| - maxiters::Int, abstol, reltol, verbose::Bool, |
42 |
| - assumptions::OperatorAssumptions) |
43 |
| - qr(CUDA.CuArray(A)) |
44 |
| -end |
45 |
| - |
46 |
| -function LinearSolve.init_cacheval(::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
47 |
| - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
48 |
| - nothing |
49 |
| -end |
50 |
| - |
51 |
| -function LinearSolve.init_cacheval(::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
52 |
| - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
53 |
| - nothing |
54 |
| -end |
55 |
| - |
56 |
| -function LinearSolve.init_cacheval(::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
57 |
| - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
58 |
| - nothing |
59 |
| -end |
60 |
| - |
61 |
| -end |
| 1 | +module LinearSolveCUDAExt |
| 2 | + |
| 3 | +using CUDA |
| 4 | +using LinearSolve |
| 5 | +using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
| 6 | +using SciMLBase: AbstractSciMLOperator |
| 7 | + |
| 8 | +function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b, |
| 9 | + assump::OperatorAssumptions{Bool}) where {Tv, Ti} |
| 10 | + if LinearSolve.cudss_loaded(A) |
| 11 | + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) |
| 12 | + else |
| 13 | + if !LinearSolve.ALREADY_WARNED_CUDSS[] |
| 14 | + @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov") |
| 15 | + LinearSolve.ALREADY_WARNED_CUDSS[] = true |
| 16 | + end |
| 17 | + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) |
| 18 | + end |
| 19 | +end |
| 20 | + |
| 21 | +function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) |
| 22 | + if !LinearSolve.CUDSS_LOADED[] |
| 23 | + error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.") |
| 24 | + end |
| 25 | + nothing |
| 26 | +end |
| 27 | + |
| 28 | +function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; |
| 29 | + kwargs...) |
| 30 | + if cache.isfresh |
| 31 | + fact = qr(CUDA.CuArray(cache.A)) |
| 32 | + cache.cacheval = fact |
| 33 | + cache.isfresh = false |
| 34 | + end |
| 35 | + y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
| 36 | + cache.u .= y |
| 37 | + SciMLBase.build_linear_solution(alg, y, nothing, cache) |
| 38 | +end |
| 39 | + |
| 40 | +function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
| 41 | + maxiters::Int, abstol, reltol, verbose::Bool, |
| 42 | + assumptions::OperatorAssumptions) |
| 43 | + qr(CUDA.CuArray(A)) |
| 44 | +end |
| 45 | + |
| 46 | +function LinearSolve.init_cacheval( |
| 47 | + ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 48 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 49 | + nothing |
| 50 | +end |
| 51 | + |
| 52 | +function LinearSolve.init_cacheval( |
| 53 | + ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 54 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 55 | + nothing |
| 56 | +end |
| 57 | + |
| 58 | +function LinearSolve.init_cacheval( |
| 59 | + ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 60 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 61 | + nothing |
| 62 | +end |
| 63 | + |
| 64 | +end |
0 commit comments