Skip to content

Commit 09ef437

Browse files
committed
Add an argument storev in the function lartf!
1 parent 08f1c00 commit 09ef437

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,11 @@ function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasF
193193
end
194194

195195
# Xlarft!
196-
function larft!(direct::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
196+
function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
197197
n, k = size(v)
198198
ktau = length(tau)
199199
mt, nt = size(t)
200+
(storev != 'C') && throw(ArgumentError("Only storev = 'C' is supported."))
200201
(n < k) && throw(ArgumentError("The number of elementary reflectors ($k) must be lower or equal to the order of block reflector H ($n)."))
201202
(ktau != k) && throw(ArgumentError("The length of tau ($ktau) is not equal to the number of elementary reflectors ($k)."))
202203
(mt != k || nt != k) && throw(ArgumentError("The size of the triangular factor of the block reflector is ($mt, $nt) and must be ($k, $k)."))
@@ -207,12 +208,12 @@ function larft!(direct::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t:
207208
function bufferSize()
208209
out_cpu = Ref{Csize_t}(0)
209210
out_gpu = Ref{Csize_t}(0)
210-
cusolverDnXlarft_bufferSize(dense_handle(), params, direct, 'C', n, k, T,
211+
cusolverDnXlarft_bufferSize(dense_handle(), params, direct, storev, n, k, T,
211212
v, ldv, T, tau, T, t, ldt, T, out_gpu, out_cpu)
212213
out_gpu[], out_cpu[]
213214
end
214215
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
215-
cusolverDnXlarft(dense_handle(), params, direct, 'C', n, k, T, v, ldv, T, tau, T, t,
216+
cusolverDnXlarft(dense_handle(), params, direct, storev, n, k, T, v, ldv, T, tau, T, t,
216217
ldt, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu))
217218
end
218219

test/libraries/cusolver/dense_generic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ p = 5
2727
end
2828
dv = CuArray(v)
2929
dt = CuMatrix(t)
30-
dt = CUSOLVER.larft!(direct, dv, dτ, dt)
30+
dt = CUSOLVER.larft!(direct, 'C', dv, dτ, dt)
3131
@test dI - dv * dt * dv' dH
3232
end
3333
end

0 commit comments

Comments
 (0)