Skip to content

Commit 4e9513b

Browse files
authored
[CUSOLVER] Interface gesv! and gels! (#2406)
1 parent 19a08ef commit 4e9513b

File tree

9 files changed

+347
-65
lines changed

9 files changed

+347
-65
lines changed

lib/cusolver/CUSOLVER.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("libcusolverMg.jl")
3333
include("libcusolverRF.jl")
3434

3535
# low-level wrappers
36+
include("helpers.jl")
3637
include("error.jl")
3738
include("base.jl")
3839
include("sparse.jl")

lib/cusolver/base.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,63 @@ function Base.convert(::Type{cusolverDirectMode_t}, direct::Char)
6363
throw(ArgumentError("Unknown direction mode $direct."))
6464
end
6565
end
66+
67+
"""
    convert(::Type{cusolverIRSRefinement_t}, irs::String)

Translate the name of an iterative refinement solver into the corresponding
`cusolverIRSRefinement_t` enum value.

Recognized names are `"NOT_SET"`, `"NONE"`, `"CLASSICAL"`, `"CLASSICAL_GMRES"`,
`"GMRES"`, `"GMRES_GMRES"` and `"GMRES_NOPCOND"`.

# Throws
- `ArgumentError` if `irs` is not one of the recognized names.
"""
function Base.convert(::Type{cusolverIRSRefinement_t}, irs::String)
    # Every branch must compare against `irs`: a bare string literal is not a `Bool`
    # and raises a `TypeError` when used as a condition.
    if irs == "NOT_SET"
        return CUSOLVER_IRS_REFINE_NOT_SET
    elseif irs == "NONE"
        return CUSOLVER_IRS_REFINE_NONE
    elseif irs == "CLASSICAL"
        return CUSOLVER_IRS_REFINE_CLASSICAL
    elseif irs == "CLASSICAL_GMRES"
        return CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
    elseif irs == "GMRES"
        return CUSOLVER_IRS_REFINE_GMRES
    elseif irs == "GMRES_GMRES"
        return CUSOLVER_IRS_REFINE_GMRES_GMRES
    elseif irs == "GMRES_NOPCOND"
        return CUSOLVER_IRS_REFINE_GMRES_NOPCOND
    else
        throw(ArgumentError("Unknown iterative refinement solver $irs."))
    end
end
86+
87+
"""
    convert(::Type{cusolverPrecType_t}, T::String)

Return the `cusolverPrecType_t` enum value whose name (without the `CUSOLVER_`
prefix) equals `T`, e.g. `"R_32F"` or `"C_64F"`.

# Throws
- `ArgumentError` if `T` does not name one of the handled precisions.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::String)
    # Real-valued precisions.
    T == "R_16F"  && return CUSOLVER_R_16F
    T == "R_16BF" && return CUSOLVER_R_16BF
    T == "R_TF32" && return CUSOLVER_R_TF32
    T == "R_32F"  && return CUSOLVER_R_32F
    T == "R_64F"  && return CUSOLVER_R_64F

    # Complex-valued precisions.
    T == "C_16F"  && return CUSOLVER_C_16F
    T == "C_16BF" && return CUSOLVER_C_16BF
    T == "C_TF32" && return CUSOLVER_C_TF32
    T == "C_32F"  && return CUSOLVER_C_32F
    T == "C_64F"  && return CUSOLVER_C_64F

    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end
112+
113+
"""
    convert(::Type{cusolverPrecType_t}, T::DataType)

Return the `cusolverPrecType_t` enum value matching the Julia element type `T`.
Handled types are `Float32`, `Float64`, `Complex{Float32}` and `Complex{Float64}`.

# Throws
- `ArgumentError` for any other type.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::DataType)
    T === Float32          && return CUSOLVER_R_32F
    T === Float64          && return CUSOLVER_R_64F
    T === Complex{Float32} && return CUSOLVER_C_32F
    T === Complex{Float64} && return CUSOLVER_C_64F

    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end

lib/cusolver/dense.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,114 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
884884
end
885885
end
886886

887+
# gesv
888+
"""
    gesv!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0, tol=0.0,
          tol_inner=0.0) -> (X, info)

Solve the square linear system `A * X = B` with cuSOLVER's iterative refinement
solver (`cusolverDnIRSXgesv`), writing the solution into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback.
- `residual_history`: when `true`, request the residual history on `info`.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching the element type `T`; otherwise it must be one of the
  precisions accepted for `T` (see the checks in the body).
- `refinement_solver`: name of the refinement solver, forwarded to
  `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: iteration limits and tolerances.
  A value of zero leaves the corresponding cuSOLVER setting untouched.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` used for the solve.

# Throws
- An error from `error` if `irs_precision` is not accepted for `T`.
- Whatever `chklapackerror` raises for a nonzero solver status flag.
"""
function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO", refinement_solver::String="CLASSICAL",
               maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, tol_inner::Float64=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    n = checksquare(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    if irs_precision == "AUTO"
        # Default to the working precision of the inputs.
        (T == Float32)    && (irs_precision = "R_32F")
        (T == Float64)    && (irs_precision = "R_64F")
        (T == ComplexF32) && (irs_precision = "C_32F")
        (T == ComplexF64) && (irs_precision = "C_64F")
    else
        # Reject precisions that are not listed for the element type `T`.
        (T == Float32)    && (irs_precision in ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == Float64)    && (irs_precision in ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF32) && (irs_precision in ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF64) && (irs_precision in ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
    end
    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # Zero means "not specified": keep whatever cuSOLVER uses by default.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Query the workspace size required for this problem and parameter set.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back (scalar read of `dh.info`) and check it.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end
940+
941+
# gels
942+
"""
    gels!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0, tol=0.0,
          tol_inner=0.0) -> (X, info)

Solve the least-squares problem for the `m`-by-`n` matrix `A` and right-hand
side(s) `B` with cuSOLVER's iterative refinement solver (`cusolverDnIRSXgels`),
writing the solution into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback.
- `residual_history`: when `true`, request the residual history on `info`.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching the element type `T`; otherwise it must be one of the
  precisions accepted for `T` (see the checks in the body).
- `refinement_solver`: name of the refinement solver, forwarded to
  `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: iteration limits and tolerances.
  A value of zero leaves the corresponding cuSOLVER setting untouched.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` used for the solve.

# Throws
- An error from `error` if `irs_precision` is not accepted for `T`.
- Whatever `chklapackerror` raises for a nonzero solver status flag.
"""
function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO", refinement_solver::String="CLASSICAL",
               maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, tol_inner::Float64=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    m, n = size(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    if irs_precision == "AUTO"
        # Default to the working precision of the inputs.
        (T == Float32)    && (irs_precision = "R_32F")
        (T == Float64)    && (irs_precision = "R_64F")
        (T == ComplexF32) && (irs_precision = "C_32F")
        (T == ComplexF64) && (irs_precision = "C_64F")
    else
        # Reject precisions that are not listed for the element type `T`.
        (T == Float32)    && (irs_precision in ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == Float64)    && (irs_precision in ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF32) && (irs_precision in ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF64) && (irs_precision in ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
    end
    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # Zero means "not specified": keep whatever cuSOLVER uses by default.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Query the workspace size required for this problem and parameter set.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back (scalar read of `dh.info`) and check it.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end
994+
887995
# LAPACK
888996
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
889997
@eval begin

lib/cusolver/dense_generic.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
"""
    CuSolverParameters()

Wrapper around a `cusolverDnParams_t` handle created with `cusolverDnCreateParams`.
A finalizer calling `cusolverDnDestroyParams` is registered on the wrapper.
"""
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverDnDestroyParams, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `cusolverDnParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end
14-
151
# Xpotrf
162
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
173
chkuplo(uplo)

lib/cusolver/error.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ function description(err)
2929
"an internal operation failed"
3030
elseif err.code == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED
3131
"the matrix type is not supported."
32+
elseif err.code == CUSOLVER_STATUS_NOT_SUPPORTED
33+
"the parameter combination is not supported."
3234
else
3335
"no description for this error"
3436
end

lib/cusolver/helpers.jl

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# cuSOLVER helper functions
2+
3+
## SparseQRInfo
4+
5+
"""
    SparseQRInfo()

Wrapper around a `csrqrInfo_t` handle created with `cusolverSpCreateCsrqrInfo`.
A finalizer calling `cusolverSpDestroyCsrqrInfo` is registered on the wrapper.
"""
mutable struct SparseQRInfo
    info::csrqrInfo_t

    function SparseQRInfo()
        handle = Ref{csrqrInfo_t}()
        cusolverSpCreateCsrqrInfo(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverSpDestroyCsrqrInfo, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `csrqrInfo_t` is expected.
function Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo)
    return info.info
end
18+
19+
20+
## SparseCholeskyInfo
21+
22+
"""
    SparseCholeskyInfo()

Wrapper around a `csrcholInfo_t` handle created with `cusolverSpCreateCsrcholInfo`.
A finalizer calling `cusolverSpDestroyCsrcholInfo` is registered on the wrapper.
"""
mutable struct SparseCholeskyInfo
    info::csrcholInfo_t

    function SparseCholeskyInfo()
        handle = Ref{csrcholInfo_t}()
        cusolverSpCreateCsrcholInfo(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverSpDestroyCsrcholInfo, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `csrcholInfo_t` is expected.
function Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo)
    return info.info
end
35+
36+
37+
## CuSolverParameters
38+
39+
"""
    CuSolverParameters()

Wrapper around a `cusolverDnParams_t` handle created with `cusolverDnCreateParams`.
A finalizer calling `cusolverDnDestroyParams` is registered on the wrapper.
"""
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverDnDestroyParams, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `cusolverDnParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end
52+
53+
54+
## CuSolverIRSParameters
55+
56+
"""
    CuSolverIRSParameters()

Wrapper around a `cusolverDnIRSParams_t` handle created with
`cusolverDnIRSParamsCreate`. A finalizer calling `cusolverDnIRSParamsDestroy`
is registered on the wrapper.
"""
mutable struct CuSolverIRSParameters
    parameters::cusolverDnIRSParams_t

    function CuSolverIRSParameters()
        handle = Ref{cusolverDnIRSParams_t}()
        cusolverDnIRSParamsCreate(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverDnIRSParamsDestroy, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `cusolverDnIRSParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnIRSParams_t}, params::CuSolverIRSParameters)
    return params.parameters
end
69+
70+
"""
    get_info(params::CuSolverIRSParameters, field::Symbol)

Query a setting stored in `params`. The only supported `field` is `:maxiters`,
read through `cusolverDnIRSParamsGetMaxIters`; any other symbol raises an error.
"""
function get_info(params::CuSolverIRSParameters, field::Symbol)
    # Guard clause: reject every field except the single supported one.
    field == :maxiters || error("The information $field is incorrect.")

    value = Ref{Cint}()
    cusolverDnIRSParamsGetMaxIters(params, value)
    return value[]
end
79+
80+
81+
## CuSolverIRSInformation
82+
83+
"""
    CuSolverIRSInformation()

Wrapper around a `cusolverDnIRSInfos_t` handle created with
`cusolverDnIRSInfosCreate`. A finalizer calling `cusolverDnIRSInfosDestroy`
is registered on the wrapper.
"""
mutable struct CuSolverIRSInformation
    information::cusolverDnIRSInfos_t

    function CuSolverIRSInformation()
        handle = Ref{cusolverDnIRSInfos_t}()
        cusolverDnIRSInfosCreate(handle)
        # `finalizer` returns the object it was attached to.
        return finalizer(cusolverDnIRSInfosDestroy, new(handle[]))
    end
end

# Hand out the raw handle when the wrapper is passed where a `cusolverDnIRSInfos_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnIRSInfos_t}, info::CuSolverIRSInformation)
    return info.information
end
96+
97+
"""
    get_info(info::CuSolverIRSInformation, field::Symbol)

Query an integer counter stored in `info`. Supported values of `field`:

- `:niters`: read through `cusolverDnIRSInfosGetNiters`
- `:outer_niters`: read through `cusolverDnIRSInfosGetOuterNiters`
- `:maxiters`: read through `cusolverDnIRSInfosGetMaxIters`

Any other symbol raises an error.
"""
function get_info(info::CuSolverIRSInformation, field::Symbol)
    # NOTE: `:residual_history` (cusolverDnIRSInfosGetResidualHistory) is not wired up yet.
    counter = Ref{Cint}()
    if field == :niters
        cusolverDnIRSInfosGetNiters(info, counter)
    elseif field == :outer_niters
        cusolverDnIRSInfosGetOuterNiters(info, counter)
    elseif field == :maxiters
        cusolverDnIRSInfosGetMaxIters(info, counter)
    else
        error("The information $field is incorrect.")
    end
    return counter[]
end
118+
119+
120+
## MatrixDescriptor
121+
122+
"""
    MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2), elta=eltype(a))

Wrapper around a `cudaLibMgMatrixDesc_t` created with `cusolverMgCreateMatrixDesc`
from the dimensions of `a`, the block sizes `rowblocks`/`colblocks`, the element
type `elta` and the device `grid`.
"""
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) )
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Hand out the raw descriptor when the wrapper is passed where a `cudaLibMgMatrixDesc_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end
133+
134+
135+
## DeviceGrid
136+
137+
"""
    DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)

Wrapper around a `cudaLibMgGrid_t` created with `cusolverMgCreateDeviceGrid`.
`num_row_devs` must equal 1; this is enforced with an `@assert`.
"""
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Hand out the raw grid handle when the wrapper is passed where a `cudaLibMgGrid_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end

lib/cusolver/multigpu.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,6 @@
77
# NOTE: in the cublasMg preview, which also relies on this functionality, a separate library
88
# called 'cudalibmg' is introduced. factor this out when we actually ship that.
99

10-
"""
    MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2), elta=eltype(a))

Wrapper around a `cudaLibMgMatrixDesc_t` created with `cusolverMgCreateMatrixDesc`
from the dimensions of `a`, the block sizes `rowblocks`/`colblocks`, the element
type `elta` and the device `grid`.
"""
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) )
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Hand out the raw descriptor when the wrapper is passed where a `cudaLibMgMatrixDesc_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end
21-
22-
"""
    DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)

Wrapper around a `cudaLibMgGrid_t` created with `cusolverMgCreateDeviceGrid`.
`num_row_devs` must equal 1; this is enforced with an `@assert`.
"""
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Hand out the raw grid handle when the wrapper is passed where a `cudaLibMgGrid_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end
34-
3510
function allocateBuffers(n_row_devs, n_col_devs, mat::Matrix)
3611
mat_row_block_size = div(size(mat, 1), n_row_devs)
3712
mat_col_block_size = div(size(mat, 2), n_col_devs)

0 commit comments

Comments
 (0)