Skip to content

Commit 08f1c00

Browse files
committed
[CUSOLVER] Interface larft!
1 parent f5100a1 commit 08f1c00

File tree

5 files changed

+77
-4
lines changed

5 files changed

+77
-4
lines changed

lib/cusolver/base.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,23 @@ function Base.convert(::Type{cusolverEigRange_t}, range::Char)
4343
throw(ArgumentError("Unknown eigenvalue solver range $range."))
4444
end
4545
end
46+
47+
function Base.convert(::Type{cusolverStorevMode_t}, storev::Char)
48+
if storev == 'C'
49+
CUBLAS_STOREV_COLUMNWISE
50+
elseif storev == 'R'
51+
CUBLAS_STOREV_ROWWISE
52+
else
53+
throw(ArgumentError("Unknown storage mode $storev."))
54+
end
55+
end
56+
57+
function Base.convert(::Type{cusolverDirectMode_t}, direct::Char)
58+
if direct == 'F'
59+
CUBLAS_DIRECT_FORWARD
60+
elseif direct == 'B'
61+
CUBLAS_DIRECT_BACKWARD
62+
else
63+
throw(ArgumentError("Unknown direction mode $direct."))
64+
end
65+
end

lib/cusolver/dense_generic.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,33 @@ function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasF
192192
A
193193
end
194194

195+
# Xlarft!
196+
function larft!(direct::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
197+
n, k = size(v)
198+
ktau = length(tau)
199+
mt, nt = size(t)
200+
(n < k) && throw(ArgumentError("The number of elementary reflectors ($k) must be lower or equal to the order of block reflector H ($n)."))
201+
(ktau != k) && throw(ArgumentError("The length of tau ($ktau) is not equal to the number of elementary reflectors ($k)."))
202+
(mt != k || nt != k) && throw(ArgumentError("The size of the triangular factor of the block reflector is ($mt, $nt) and must be ($k, $k)."))
203+
ldv = max(1, stride(v, 2))
204+
ldt = max(1, stride(t, 2))
205+
params = CuSolverParameters()
206+
207+
function bufferSize()
208+
out_cpu = Ref{Csize_t}(0)
209+
out_gpu = Ref{Csize_t}(0)
210+
cusolverDnXlarft_bufferSize(dense_handle(), params, direct, 'C', n, k, T,
211+
v, ldv, T, tau, T, t, ldt, T, out_gpu, out_cpu)
212+
out_gpu[], out_cpu[]
213+
end
214+
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
215+
cusolverDnXlarft(dense_handle(), params, direct, 'C', n, k, T, v, ldv, T, tau, T, t,
216+
ldt, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu))
217+
end
218+
219+
t
220+
end
221+
195222
# Xgesvd
196223
function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
197224
m, n = size(A)

lib/cusolver/libcusolver.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4890,8 +4890,8 @@ end
48904890
dataTypeTau::cudaDataType,
48914891
d_tau::CuPtr{Cvoid}, dataTypeT::cudaDataType,
48924892
d_T::CuPtr{Cvoid}, ldt::Int64,
4893-
computeType::CuPtr{Cvoid},
4894-
bufferOnDevice::Ptr{Cvoid},
4893+
computeType::cudaDataType,
4894+
bufferOnDevice::CuPtr{Cvoid},
48954895
workspaceInBytesOnDevice::Csize_t,
48964896
bufferOnHost::Ptr{Cvoid},
48974897
workspaceInBytesOnHost::Csize_t)::cusolverStatus_t

res/wrap/cusolver.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2338,4 +2338,4 @@ needs_context = false
23382338
8 = "CuPtr{Cvoid}"
23392339
11 = "CuPtr{Cvoid}"
23402340
13 = "CuPtr{Cvoid}"
2341-
15 = "CuPtr{Cvoid}"
2341+
16 = "CuPtr{Cvoid}"

test/libraries/cusolver/dense_generic.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,33 @@ m = 15
55
n = 10
66
p = 5
77

8-
@testset "cusolver -- generic API -- $elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
8+
@testset "cusolver -- generic API -- $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
9+
@testset "larft!" begin
10+
@testset "direct = $direct" for direct in ('F', 'B')
11+
direct == 'B' && continue
12+
A = rand(elty,m,n)
13+
t = rand(elty,n,n)
14+
15+
dA = CuMatrix(A)
16+
dA, dτ = CUSOLVER.geqrf!(dA)
17+
hI = Matrix{elty}(I, m, m)
18+
dI = CuArray(hI)
19+
dH = CUSOLVER.ormqr!('L', 'N', dA, dτ, copy(dI))
20+
21+
v = Array(dA)
22+
for j = 1:n
23+
v[j,j] = one(elty)
24+
for i = 1:j-1
25+
v[i,j] = zero(elty)
26+
end
27+
end
28+
dv = CuArray(v)
29+
dt = CuMatrix(t)
30+
dt = CUSOLVER.larft!(direct, dv, dτ, dt)
31+
@test dI - dv * dt * dv' dH
32+
end
33+
end
34+
935
@testset "sytrs!" begin
1036
for uplo in ('L', 'U')
1137
A = rand(elty,n,n)

0 commit comments

Comments
 (0)